Skip to content

Commit

Permalink
Merge pull request #857 from CliMA/gb/leaderboard
Browse files Browse the repository at this point in the history
Fix discard spinup in leaderboad and add best/worst single models
  • Loading branch information
Sbozzolo authored Jun 26, 2024
2 parents 9a9ac73 + 9102350 commit a4046f9
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 82 deletions.
26 changes: 13 additions & 13 deletions experiments/ClimaEarth/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.4"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "afda340012fd05e6c7fd2baf7decd5033d3e5be2"
project_hash = "21c747b53577d075d3bfa396efb04bb9b1e1c965"

[[deps.ADTypes]]
git-tree-sha1 = "fa0822e5baee6e23081c2685ae27265dabee23d8"
git-tree-sha1 = "3a6511b6e54550bcbc986c560921a8cd7761fcd8"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
version = "1.4.0"
version = "1.5.1"
weakdeps = ["ChainRulesCore", "EnzymeCore"]

[deps.ADTypes.extensions]
Expand Down Expand Up @@ -119,9 +119,9 @@ version = "7.11.0"

[[deps.ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "420e2853770f50e5306b9d96b5a66f26e7c73bc6"
git-tree-sha1 = "600078184f7de14b3e60efe13fc0ba5c59f6dca5"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "1.9.4"
version = "1.10.0"
weakdeps = ["SparseArrays"]

[deps.ArrayLayouts.extensions]
Expand Down Expand Up @@ -311,10 +311,10 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaAnalysis]]
deps = ["NCDatasets", "OrderedCollections", "Statistics"]
git-tree-sha1 = "c2e1c0d5c30a2519a4282988037b255dbc9aee00"
deps = ["NCDatasets", "OrderedCollections", "Reexport", "Statistics"]
git-tree-sha1 = "69c740df5906f48a5739588d7dadf772311d8b7d"
uuid = "29b5916a-a76c-4e73-9657-3c8fd22e65e6"
version = "0.5.3"
version = "0.5.4"
weakdeps = ["CairoMakie", "GeoMakie"]

[deps.ClimaAnalysis.extensions]
Expand Down Expand Up @@ -1739,9 +1739,9 @@ version = "2024.1.0+0"

[[deps.MPI]]
deps = ["Distributed", "DocStringExtensions", "Libdl", "MPICH_jll", "MPIPreferences", "MPItrampoline_jll", "MicrosoftMPI_jll", "OpenMPI_jll", "PkgVersion", "PrecompileTools", "Requires", "Serialization", "Sockets"]
git-tree-sha1 = "4e3136db3735924f96632a5b40a5979f1f53fa07"
git-tree-sha1 = "14cef41baf5b675b192b02a22c710f725ab333a7"
uuid = "da04e1cc-30fd-572f-bb4f-1f8673147195"
version = "0.20.19"
version = "0.20.20"

[deps.MPI.extensions]
AMDGPUExt = "AMDGPU"
Expand Down Expand Up @@ -2830,9 +2830,9 @@ uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.24"

[[deps.TranscodingStreams]]
git-tree-sha1 = "a947ea21087caba0a798c5e494d0bb78e3a1a3a0"
git-tree-sha1 = "d73336d81cafdc277ff45558bb7eaa2b04a8e472"
uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
version = "0.10.9"
version = "0.10.10"
weakdeps = ["Random", "Test"]

[deps.TranscodingStreams.extensions]
Expand Down
2 changes: 1 addition & 1 deletion experiments/ClimaEarth/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
ArgParse = "1.1"
ArtifactWrappers = "0.2"
AtmosphericProfilesLibrary = "0.1"
ClimaAnalysis = "0.5.3"
ClimaAnalysis = "0.5.4"
ClimaAtmos = "0.26"
ClimaCorePlots = "0.2"
ClimaLand = "0.12"
Expand Down
26 changes: 17 additions & 9 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -925,32 +925,40 @@ if ClimaComms.iamroot(comms_ctx)
## Compare against observations
if t_end > 84600 && config_dict["output_default_diagnostics"]
@info "Error against observations"
diagnostics_times = copy(atmos_sim.integrator.sol.t)
include("user_io/leaderboard.jl")
ClimaAnalysis = Leaderboard.ClimaAnalysis

compare_vars = ["pr", "rsut", "rlut"]
diagnostics_folder_path = atmos_sim.integrator.p.output_dir
leaderboard_base_path = dir_paths.artifacts

first_var = get(ClimaAnalysis.SimDir(diagnostics_folder_path), short_name = first(compare_vars))

diagnostics_times = ClimaAnalysis.times(first_var)
# Remove the first `spinup_months` months from the leaderboard
spinup_months = 6
spinup_cutoff = spinup_months * 30 * 86400.0
if t_end > spinup_cutoff
filter!(x -> x < spinup_cutoff, diagnostics_times)
if diagnostics_times[end] > spinup_cutoff
filter!(x -> x > spinup_cutoff, diagnostics_times)
end

output_dates = cs.dates.date0[] .+ Dates.Second.(diagnostics_times)
output_dates = Dates.DateTime(first_var.attributes["start_date"]) .+ Dates.Second.(diagnostics_times)

@info "Working with dates:"
@info output_dates

include("user_io/leaderboard.jl")
compare_vars = ["pr", "rsut", "rlut"]
function compute_biases(dates)
if isempty(dates)
return map(x -> 0.0, compare_vars)
else
return Leaderboard.compute_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates)
return Leaderboard.compute_biases(diagnostics_folder_path, compare_vars, dates)
end
end

function plot_biases(dates, biases, output_name)
isempty(dates) && return nothing

output_path = joinpath(dir_paths.artifacts, "bias_$(output_name).png")
output_path = joinpath(leaderboard_base_path, "bias_$(output_name).png")
Leaderboard.plot_biases(biases; output_path)
end

Expand Down Expand Up @@ -981,7 +989,7 @@ if ClimaComms.iamroot(comms_ctx)
1:length(compare_vars),
)

Leaderboard.plot_leaderboard(rmses; output_path = joinpath(dir_paths.artifacts, "bias_leaderboard.png"))
Leaderboard.plot_leaderboard(rmses; output_path = joinpath(leaderboard_base_path, "bias_leaderboard.png"))
end
end

Expand Down
144 changes: 89 additions & 55 deletions experiments/ClimaEarth/user_io/leaderboard/cmip_rmse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,67 +35,101 @@ function best_single_model(RMSEs)
end

"""
RSME_stats(RMSEs)
worst_single_model(RMSEs)
Return the one model that has the overall largest error.
"""
function worst_single_model(RMSEs)
_, index = findmax(r -> abs.(values(r)), RMSEs)
return RMSEs[index]
end

"""
RMSE_stats(RMSEs)
RMSEs is the dictionary OTHER_MODELS_RMSEs.
Return:
- best single model
- worst single model
- "model" with all the medians
- "model" with all the best values
- "model" with all the worst values
"""
function RSME_stats(vecRMSEs)
# Collect into vectors that we can process independently
all_values = stack(values.(vecRMSEs))
ANN, DJF, JJA, MAM, SON = ntuple(i -> all_values[i, :], 5)

median_model = RMSEs(;
model_name = "Median",
ANN = median(ANN),
DJF = median(DJF),
JJA = median(JJA),
MAM = median(MAM),
SON = median(SON),
)

worst_model = RMSEs(;
model_name = "Worst",
ANN = maximum(abs.(ANN)),
DJF = maximum(abs.(DJF)),
JJA = maximum(abs.(JJA)),
MAM = maximum(abs.(MAM)),
SON = maximum(abs.(SON)),
)

best_model = RMSEs(;
model_name = "Best",
ANN = minimum(abs.(ANN)),
DJF = minimum(abs.(DJF)),
JJA = minimum(abs.(JJA)),
MAM = minimum(abs.(MAM)),
SON = minimum(abs.(SON)),
)

quantile25 = RMSEs(;
model_name = "Quantile 0.25",
ANN = quantile(ANN, 0.25),
DJF = quantile(DJF, 0.25),
JJA = quantile(JJA, 0.25),
MAM = quantile(MAM, 0.25),
SON = quantile(SON, 0.25),
)

quantile75 = RMSEs(;
model_name = "Quantile 0.75",
ANN = quantile(ANN, 0.75),
DJF = quantile(DJF, 0.75),
JJA = quantile(JJA, 0.75),
MAM = quantile(MAM, 0.75),
SON = quantile(SON, 0.75),
)

(; best_single_model = best_single_model(vecRMSEs), median_model, worst_model, best_model, quantile25, quantile75)
end
function RMSE_stats(dict_vecRMSEs)
stats = Dict()
# cumulative_error maps model_names with the total RMSE across metrics normalized by median(RMSE)
cumulative_error = Dict()
for (key, vecRMSEs) in dict_vecRMSEs
# Collect into vectors that we can process independently
all_values = stack(values.(vecRMSEs))
ANN, DJF, JJA, MAM, SON = ntuple(i -> all_values[i, :], 5)

median_model = RMSEs(;
model_name = "Median",
ANN = median(ANN),
DJF = median(DJF),
JJA = median(JJA),
MAM = median(MAM),
SON = median(SON),
)

worst_model = RMSEs(;
model_name = "Worst",
ANN = maximum(abs.(ANN)),
DJF = maximum(abs.(DJF)),
JJA = maximum(abs.(JJA)),
MAM = maximum(abs.(MAM)),
SON = maximum(abs.(SON)),
)

best_model = RMSEs(;
model_name = "Best",
ANN = minimum(abs.(ANN)),
DJF = minimum(abs.(DJF)),
JJA = minimum(abs.(JJA)),
MAM = minimum(abs.(MAM)),
SON = minimum(abs.(SON)),
)

quantile25 = RMSEs(;
model_name = "Quantile 0.25",
ANN = quantile(ANN, 0.25),
DJF = quantile(DJF, 0.25),
JJA = quantile(JJA, 0.25),
MAM = quantile(MAM, 0.25),
SON = quantile(SON, 0.25),
)

quantile75 = RMSEs(;
model_name = "Quantile 0.75",
ANN = quantile(ANN, 0.75),
DJF = quantile(DJF, 0.75),
JJA = quantile(JJA, 0.75),
MAM = quantile(MAM, 0.75),
SON = quantile(SON, 0.75),
)

for rmse in vecRMSEs
haskey(cumulative_error, cumulative_error) || (cumulative_error[rmse.model_name] = 0.0)
cumulative_error[rmse.model_name] += sum(values(rmse) ./ values(median_model))
end

for short_name in short_names
COMPARISON_RMSEs[short_name] = RSME_stats(OTHER_MODELS_RMSEs[short_name])
stats[key] = (;
best_single_model = best_single_model(vecRMSEs),
worst_single_model = worst_single_model(vecRMSEs),
median_model,
worst_model,
best_model,
quantile25,
quantile75,
)
end

_, absolute_best_model = findmin(cumulative_error)
_, absolute_worst_model = findmax(cumulative_error)

return (; stats, absolute_best_model, absolute_worst_model)
end

const COMPARISON_RMSEs_STATS = RMSE_stats(OTHER_MODELS_RMSEs)
17 changes: 13 additions & 4 deletions experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import CairoMakie
const OBS_DS = Dict()
const SIM_DS_KWARGS = Dict()
const OTHER_MODELS_RMSEs = Dict()
const COMPARISON_RMSEs = Dict()

function preprocess_pr_fn(data)
# -1 kg/m/s2 -> 1 mm/day
Expand Down Expand Up @@ -167,6 +166,8 @@ function plot_leaderboard(rmses; output_path)
# models compared, and there is one row per variable
squares = zeros(NUM_BOXES * NUM_MODELS, num_variables)

(; absolute_best_model, absolute_worst_model) = COMPARISON_RMSEs_STATS

for (var_num, rmse) in enumerate(rmses)
short_name = rmse.ANN.attributes["var_short_name"]
units = rmse.ANN.attributes["units"]
Expand All @@ -178,7 +179,11 @@ function plot_leaderboard(rmses; output_path)
)

# Against other models
(; best_single_model, median_model, worst_model, best_model) = COMPARISON_RMSEs[short_name]

(; median_model) = COMPARISON_RMSEs_STATS.stats[short_name]

best_single_model = first(filter(x -> x.model_name == absolute_best_model, OTHER_MODELS_RMSEs[short_name]))
worst_single_model = first(filter(x -> x.model_name == absolute_worst_model, OTHER_MODELS_RMSEs[short_name]))

squares[begin:NUM_BOXES, end - var_num + 1] .= values(rmse) ./ values(median_model)
squares[(NUM_BOXES + 1):end, end - var_num + 1] .= values(best_single_model) ./ values(median_model)
Expand All @@ -190,7 +195,8 @@ function plot_leaderboard(rmses; output_path)
label = median_model.model_name,
color = :black,
marker = :hline,
markersize = 15,
markersize = 10,
visible = false,
)

categories = vcat(map(_ -> collect(1:5), 1:length(OTHER_MODELS_RMSEs[short_name]))...)
Expand All @@ -206,6 +212,9 @@ function plot_leaderboard(rmses; output_path)
whiskerlinewidth = 1,
)

CairoMakie.scatter!(ax, 1:5, values(best_single_model), label = absolute_best_model)
CairoMakie.scatter!(ax, 1:5, values(worst_single_model), label = absolute_worst_model)

# If we want to plot other models
# for model in OTHER_MODELS_RMSEs[short_name]
# CairoMakie.scatter!(ax, 1:5, values(model), marker = :hline)
Expand All @@ -218,7 +227,7 @@ function plot_leaderboard(rmses; output_path)
label = rmse.model_name,
marker = :star5,
markersize = 20,
color = :orange,
color = :green,
)

# Add a fake extra point to center the legend a little better
Expand Down

0 comments on commit a4046f9

Please sign in to comment.