Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reorganize restart tests #3351

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 143 additions & 102 deletions test/restart.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import ClimaCore.Spaces: AbstractSpace
import ClimaComms
pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends
import Logging
import NCDatasets
using Test

import Random
Expand All @@ -35,10 +36,6 @@ ClimaComms.init(comms_ctx)
# different.
#
# For this reason, we don't use Test but just print to screen the differences.
# However, we still have to return an exit code with failure in case of the
# comparison fails. So, we have this global `SUCCESS` bool that is updated by
# the result of tests.
const SUCCESS::Base.RefValue{Bool} = Ref(true)

"""
_error(arr1::AbstractArray, arr2::AbstractArray; ABS_TOL = 100eps(eltype(arr1)))
Expand Down Expand Up @@ -130,6 +127,17 @@ function _compare(v1::T, v2::T; name, ignore) where {T <: Number}
return print_maybe(v1 === v2, "$name differs: $v1 vs $v2")
end

# We ignore NCDatasets. They contain a lot of state-ful information
function _compare(
pass,
v1::T,
v2::T;
name,
ignore,
) where {T <: NCDatasets.NCDataset}
return pass
end

function _compare(
v1::T,
v2::T;
Expand All @@ -143,6 +151,17 @@ function _compare(pass, v1::T, v2::T; name, ignore) where {T <: AbstractData}
return pass && _compare(parent(v1), parent(v2); name, ignore)
end

# Handle views
function _compare(
pass,
v1::SubArray{FT},
v2::SubArray{FT};
name,
ignore,
) where {FT <: AbstractFloat}
return pass && _compare(collect(v1), collect(v2); name, ignore)
end

function _compare(
v1::AbstractArray{FT},
v2::AbstractArray{FT};
Expand All @@ -167,6 +186,118 @@ end
# Disable all the @info statements that are produced when creating a simulation
Logging.disable_logging(Logging.Info)


"""
test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[])

Test if the restarts are consistent for a simulation defined by the `test_dict` config.

`more_ignore` is a Vector of Symbols that identifies config-specific keys that
have to be ignored when reading a simulation.
"""
function test_restart(test_dict; job_id, comms_ctx, more_ignore = Symbol[])
println("job_id = $(job_id)")

local_success = Ref(true)

config = CA.AtmosConfig(test_dict; job_id, comms_ctx)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

# Reset random seed for RRTMGP
Random.seed!(1234)

ClimaComms.iamroot(comms_ctx) && println(" just reading data")
config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true));
job_id,
comms_ctx,
)

simulation_restarted = CA.get_simulation(config_should_be_same)

local_success[] &= compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
local_success[] &= compare(
axes(simulation.integrator.u.c),
axes(simulation_restarted.integrator.u.c);
name = "space",
)
local_success[] &= compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
ignore = Set([
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:scratch,
:output_dir,
:ghost_buffer,
# Computed in tendencies (which are not computed in this case)
:hyperdiff,
:precipitation,
# rc is some CUDA/CuArray internal object that we don't care about
:rc,
# DataHandlers contains caches, so they are stateful
:data_handler,
# Config-specific
more_ignore...,
]),
)

# Check re-importing from previous state and advancing one step
ClimaComms.iamroot(comms_ctx) && println(" reading and simulating")
# Reset random seed for RRTMGP
Random.seed!(1234)

restart_file = joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
# Restart from specific file
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file));
job_id,
comms_ctx,
)

simulation_restarted2 = CA.get_simulation(config2)
CA.fill_with_nans!(simulation_restarted2.integrator.p)

CA.solve_atmos!(simulation_restarted2)
local_success[] &= compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
local_success[] &= compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
ignore = Set([
:scratch,
:output_dir,
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:data_handler,
:rc,
]),
)

return local_success[]
end

# Let's prepare the test_dicts. TESTING is a Vector of NamedTuples, each element
# has a test_dict, a job_id, and a more_ignore

TESTING = Any[]

if comms_ctx isa ClimaComms.SingletonCommsContext
configurations = ["sphere", "box", "column"]
else
Expand All @@ -180,13 +311,13 @@ for configuration in configurations
topography = "Earth"
turbconv_models = [nothing, "diagnostic_edmfx"]
# turbconv_models = ["prognostic_edmfx"]
radiations = [nothing]
radiations = [nothing, "gray"]
else
moistures = ["equil"]
precips = ["1M"]
topography = "NoWarp"
turbconv_models = ["diagnostic_edmfx"]
radiations = [nothing]
radiations = [nothing, "gray"]
end

for turbconv_mode in turbconv_models
Expand All @@ -200,9 +331,6 @@ for configuration in configurations
end
end

println(
"config = $configuration $moisture $precip $topography $radiation",
)
# The `enable_bubble` case is broken for ClimaCore < 0.14.6, so we
# hard-code this to be always false for those versions
bubble = pkgversion(ClimaCore) > v"0.14.5"
Expand All @@ -211,9 +339,8 @@ for configuration in configurations
output_loc =
ClimaComms.iamroot(comms_ctx) ? mktempdir(pwd()) : ""
output_loc = ClimaComms.bcast(comms_ctx, output_loc)
ClimaComms.barrier(comms_ctx)

job_id = "restart"
job_id = "$(configuration)_$(moisture)_$(precip)_$(topography)_$(radiation)"
test_dict = Dict(
"test_dycore_consistency" => true, # We will add NaNs to the cache, just to make sure
"check_nan_every" => 3,
Expand All @@ -240,103 +367,17 @@ for configuration in configurations
)
more_ignore = Symbol[]

config = CA.AtmosConfig(test_dict; job_id, comms_ctx)

simulation = CA.get_simulation(config)
CA.solve_atmos!(simulation)

# Check re-importing the same state
restart_dir = simulation.output_dir
@test isfile(joinpath(restart_dir), "day0.3.hdf5")

# Reset random seed for RRTMGP
Random.seed!(1234)

println(" just reading data")
if turbconv_mode == "prognostic_edmf"
more_ignore = [:ᶠnh_pressure₃ʲs]
end

config_should_be_same = CA.AtmosConfig(
merge(test_dict, Dict("detect_restart_file" => true));
job_id,
comms_ctx,
)

simulation_restarted =
CA.get_simulation(config_should_be_same)

SUCCESS[] &= compare(
simulation.integrator.u,
simulation_restarted.integrator.u;
name = "integrator.u",
)
SUCCESS[] &= compare(
axes(simulation.integrator.u.c),
axes(simulation_restarted.integrator.u.c);
name = "space",
)
SUCCESS[] &= compare(
simulation.integrator.p,
simulation_restarted.integrator.p;
name = "integrator.p",
ignore = Set([
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:scratch,
:output_dir,
:ghost_buffer,
# Computed in tendencies (which are not computed in this case)
:hyperdiff,
:precipitation,
# rc is some CUDA/CuArray internal object that we don't care about
:rc,
# Config-specific
more_ignore...,
]),
)

# Check re-importing from previous state and advancing one step
println(" reading and simulating")
# Reset random seed for RRTMGP
Random.seed!(1234)

restart_file =
joinpath(simulation.output_dir, "day0.2.hdf5")
@test isfile(joinpath(restart_dir), "day0.2.hdf5")
# Restart from specific file
config2 = CA.AtmosConfig(
merge(test_dict, Dict("restart_file" => restart_file));
job_id,
comms_ctx,
)

simulation_restarted2 = CA.get_simulation(config2)
CA.fill_with_nans!(simulation_restarted2.integrator.p)

CA.solve_atmos!(simulation_restarted2)
SUCCESS[] &= compare(
simulation.integrator.u,
simulation_restarted2.integrator.u;
name = "integrator.u",
)
SUCCESS[] &= compare(
simulation.integrator.p,
simulation_restarted2.integrator.p;
name = "integrator.p",
ignore = Set([
:scratch,
:output_dir,
:ghost_buffer,
:hyperdiffusion_ghost_buffer,
:rc,
]),
)
push!(TESTING, (; test_dict, job_id, more_ignore))
end
end
end
end
end

# Ensure that we have the correct exit code
@test SUCCESS[]
@test all(
@time test_restart(t.test_dict; comms_ctx, t.job_id, t.more_ignore) for
t in TESTING
)
Loading