Skip to content

Commit

Permalink
allreduce GC
Browse files Browse the repository at this point in the history
  • Loading branch information
simonbyrne committed Sep 21, 2022
1 parent fa74082 commit 4ae134e
Showing 1 changed file with 79 additions and 72 deletions.
151 changes: 79 additions & 72 deletions examples/hybrid/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ function get_callbacks(parsed_args, simulation, model_spec, params)
FT = eltype(params)
(; dt) = simulation

callback_filters = call_every_n_steps(affect_filter!; skip_first = true)
callback_filters = call_every_n_steps(affect_filter!; skip_first=true)
tc_callbacks =
call_every_n_steps(turb_conv_affect_filter!; skip_first = true)
call_every_n_steps(turb_conv_affect_filter!; skip_first=true)

additional_callbacks = if !isnothing(model_spec.radiation_model)
# TODO: better if-else criteria?
Expand Down Expand Up @@ -71,19 +71,19 @@ function get_callbacks(parsed_args, simulation, model_spec, params)
)
end

function call_every_n_steps(f!, n = 1; skip_first = false, call_at_end = false)
function call_every_n_steps(f!, n=1; skip_first=false, call_at_end=false)
previous_step = Ref(0)
return ODE.DiscreteCallback(
(u, t, integrator) ->
(previous_step[] += 1) % n == 0 ||
(call_at_end && t == integrator.sol.prob.tspan[2]),
f!;
initialize = (cb, u, t, integrator) -> skip_first || f!(integrator),
save_positions = (false, false),
initialize=(cb, u, t, integrator) -> skip_first || f!(integrator),
save_positions=(false, false)
)
end

function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
function call_every_dt(f!, dt; skip_first=false, call_at_end=false)
next_t = Ref{typeof(dt)}()
affect! = function (integrator)
t = integrator.t
Expand All @@ -98,13 +98,13 @@ function call_every_dt(f!, dt; skip_first = false, call_at_end = false)
return ODE.DiscreteCallback(
(u, t, integrator) -> t >= next_t[],
affect!;
initialize = (cb, u, t, integrator) -> begin
initialize=(cb, u, t, integrator) -> begin
skip_first || f!(integrator)
t_end = integrator.sol.prob.tspan[2]
next_t[] =
(call_at_end && t < t_end) ? min(t_end, t + dt) : t + dt
end,
save_positions = (false, false),
save_positions=(false, false)
)
end

Expand Down Expand Up @@ -190,11 +190,11 @@ function save_to_disk_func(integrator)
Spaces.weighted_dss!(ᶜvort)

dry_diagnostic = (;
pressure = ᶜp,
temperature = ᶜT,
potential_temperature = ᶜθ,
kinetic_energy = ᶜK,
vorticity = ᶜvort,
pressure=ᶜp,
temperature=ᶜT,
potential_temperature=ᶜθ,
kinetic_energy=ᶜK,
vorticity=ᶜvort
)

# cloudwater (liquid and ice), watervapor and RH for moist simulation
Expand All @@ -207,10 +207,10 @@ function save_to_disk_func(integrator)
ᶜRH = @. TD.relative_humidity(thermo_params, ᶜts)

moist_diagnostic = (;
cloud_liquid = ᶜcloud_liquid,
cloud_ice = ᶜcloud_ice,
water_vapor = ᶜwatervapor,
relative_humidity = ᶜRH,
cloud_liquid=ᶜcloud_liquid,
cloud_ice=ᶜcloud_ice,
water_vapor=ᶜwatervapor,
relative_humidity=ᶜRH
)
# precipitation
if :ᶜS_ρq_tot in propertynames(p)
Expand All @@ -232,9 +232,9 @@ function save_to_disk_func(integrator)

moist_diagnostic = (
moist_diagnostic...,
precipitation_removal = ᶜS_ρq_tot,
column_integrated_rain = col_integrated_rain,
column_integrated_snow = col_integrated_snow,
precipitation_removal=ᶜS_ρq_tot,
column_integrated_rain=col_integrated_rain,
column_integrated_snow=col_integrated_snow,
)
end
else
Expand All @@ -246,51 +246,51 @@ function save_to_disk_func(integrator)
tc_cent(p) = p.edmf_cache.aux.cent.turbconv
tc_face(p) = p.edmf_cache.aux.face.turbconv
turbulence_convection_diagnostic = (;
bulk_up_area = tc_cent(p).bulk.area,
bulk_up_h_tot = tc_cent(p).bulk.h_tot,
bulk_up_buoyancy = tc_cent(p).bulk.buoy,
bulk_up_q_tot = tc_cent(p).bulk.q_tot,
bulk_up_q_liq = tc_cent(p).bulk.q_liq,
bulk_up_q_ice = tc_cent(p).bulk.q_ice,
bulk_up_temperature = tc_cent(p).bulk.T,
bulk_up_cloud_fraction = tc_cent(p).bulk.cloud_fraction,
bulk_up_e_tot_tendency_precip_formation = tc_cent(
bulk_up_area=tc_cent(p).bulk.area,
bulk_up_h_tot=tc_cent(p).bulk.h_tot,
bulk_up_buoyancy=tc_cent(p).bulk.buoy,
bulk_up_q_tot=tc_cent(p).bulk.q_tot,
bulk_up_q_liq=tc_cent(p).bulk.q_liq,
bulk_up_q_ice=tc_cent(p).bulk.q_ice,
bulk_up_temperature=tc_cent(p).bulk.T,
bulk_up_cloud_fraction=tc_cent(p).bulk.cloud_fraction,
bulk_up_e_tot_tendency_precip_formation=tc_cent(
p,
).bulk.e_tot_tendency_precip_formation,
bulk_up_qt_tendency_precip_formation = tc_cent(
bulk_up_qt_tendency_precip_formation=tc_cent(
p,
).bulk.qt_tendency_precip_formation,
env_w = tc_cent(p).en.w,
env_area = tc_cent(p).en.area,
env_q_tot = tc_cent(p).en.q_tot,
env_q_liq = tc_cent(p).en.q_liq,
env_q_ice = tc_cent(p).en.q_ice,
env_theta_liq_ice = tc_cent(p).en.θ_liq_ice,
env_theta_virt = tc_cent(p).en.θ_virt,
env_theta_dry = tc_cent(p).en.θ_dry,
env_e_tot = tc_cent(p).en.e_tot,
env_e_kin = tc_cent(p).en.e_kin,
env_h_tot = tc_cent(p).en.h_tot,
env_RH = tc_cent(p).en.RH,
env_s = tc_cent(p).en.s,
env_temperature = tc_cent(p).en.T,
env_buoyancy = tc_cent(p).en.buoy,
env_cloud_fraction = tc_cent(p).en.cloud_fraction,
env_TKE = tc_cent(p).en.tke,
env_Hvar = tc_cent(p).en.Hvar,
env_QTvar = tc_cent(p).en.QTvar,
env_HQTcov = tc_cent(p).en.HQTcov,
env_e_tot_tendency_precip_formation = tc_cent(
env_w=tc_cent(p).en.w,
env_area=tc_cent(p).en.area,
env_q_tot=tc_cent(p).en.q_tot,
env_q_liq=tc_cent(p).en.q_liq,
env_q_ice=tc_cent(p).en.q_ice,
env_theta_liq_ice=tc_cent(p).en.θ_liq_ice,
env_theta_virt=tc_cent(p).en.θ_virt,
env_theta_dry=tc_cent(p).en.θ_dry,
env_e_tot=tc_cent(p).en.e_tot,
env_e_kin=tc_cent(p).en.e_kin,
env_h_tot=tc_cent(p).en.h_tot,
env_RH=tc_cent(p).en.RH,
env_s=tc_cent(p).en.s,
env_temperature=tc_cent(p).en.T,
env_buoyancy=tc_cent(p).en.buoy,
env_cloud_fraction=tc_cent(p).en.cloud_fraction,
env_TKE=tc_cent(p).en.tke,
env_Hvar=tc_cent(p).en.Hvar,
env_QTvar=tc_cent(p).en.QTvar,
env_HQTcov=tc_cent(p).en.HQTcov,
env_e_tot_tendency_precip_formation=tc_cent(
p,
).en.e_tot_tendency_precip_formation,
env_qt_tendency_precip_formation = tc_cent(
env_qt_tendency_precip_formation=tc_cent(
p,
).en.qt_tendency_precip_formation,
env_Hvar_rain_dt = tc_cent(p).en.Hvar_rain_dt,
env_QTvar_rain_dt = tc_cent(p).en.QTvar_rain_dt,
env_HQTcov_rain_dt = tc_cent(p).en.HQTcov_rain_dt,
face_bulk_w = tc_face(p).bulk.w,
face_env_w = tc_face(p).en.w,
env_Hvar_rain_dt=tc_cent(p).en.Hvar_rain_dt,
env_QTvar_rain_dt=tc_cent(p).en.QTvar_rain_dt,
env_HQTcov_rain_dt=tc_cent(p).en.HQTcov_rain_dt,
face_bulk_w=tc_face(p).bulk.w,
face_env_w=tc_face(p).en.w
)
else
turbulence_convection_diagnostic = NamedTuple()
Expand All @@ -299,9 +299,9 @@ function save_to_disk_func(integrator)
if vert_diff
(; dif_flux_uₕ, dif_flux_energy, dif_flux_ρq_tot) = p
vert_diff_diagnostic = (;
sfc_flux_momentum = dif_flux_uₕ,
sfc_flux_energy = dif_flux_energy,
sfc_evaporation = dif_flux_ρq_tot,
sfc_flux_momentum=dif_flux_uₕ,
sfc_flux_energy=dif_flux_energy,
sfc_evaporation=dif_flux_ρq_tot
)
else
vert_diff_diagnostic = NamedTuple()
Expand All @@ -311,36 +311,36 @@ function save_to_disk_func(integrator)
(; face_lw_flux_dn, face_lw_flux_up, face_sw_flux_dn, face_sw_flux_up) =
p.rrtmgp_model
rad_diagnostic = (;
lw_flux_down = RRTMGPI.array2field(FT.(face_lw_flux_dn), axes(Y.f)),
lw_flux_up = RRTMGPI.array2field(FT.(face_lw_flux_up), axes(Y.f)),
sw_flux_down = RRTMGPI.array2field(FT.(face_sw_flux_dn), axes(Y.f)),
sw_flux_up = RRTMGPI.array2field(FT.(face_sw_flux_up), axes(Y.f)),
lw_flux_down=RRTMGPI.array2field(FT.(face_lw_flux_dn), axes(Y.f)),
lw_flux_up=RRTMGPI.array2field(FT.(face_lw_flux_up), axes(Y.f)),
sw_flux_down=RRTMGPI.array2field(FT.(face_sw_flux_dn), axes(Y.f)),
sw_flux_up=RRTMGPI.array2field(FT.(face_sw_flux_up), axes(Y.f))
)
if model_spec.radiation_model isa
RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics
(;
face_clear_lw_flux_dn,
face_clear_lw_flux_up,
face_clear_sw_flux_dn,
face_clear_sw_flux_up,
face_clear_sw_flux_up
) = p.rrtmgp_model
rad_clear_diagnostic = (;
clear_lw_flux_down = RRTMGPI.array2field(
clear_lw_flux_down=RRTMGPI.array2field(
FT.(face_clear_lw_flux_dn),
axes(Y.f),
),
clear_lw_flux_up = RRTMGPI.array2field(
clear_lw_flux_up=RRTMGPI.array2field(
FT.(face_clear_lw_flux_up),
axes(Y.f),
),
clear_sw_flux_down = RRTMGPI.array2field(
clear_sw_flux_down=RRTMGPI.array2field(
FT.(face_clear_sw_flux_dn),
axes(Y.f),
),
clear_sw_flux_up = RRTMGPI.array2field(
clear_sw_flux_up=RRTMGPI.array2field(
FT.(face_clear_sw_flux_up),
axes(Y.f),
),
)
)
else
rad_clear_diagnostic = NamedTuple()
Expand Down Expand Up @@ -392,7 +392,14 @@ function save_restart_func(integrator)
end

function gc_func(integrator)
@info "Calling GC" "free mem (MB)"=Sys.free_memory()/2^20 "total mem (MB)"=Sys.total_memory()/2^20
GC.gc()
free_mem = Sys.free_memory()
total_mem = Sys.total_memory()
p_free_mem = free_mem / total_mem
min_p_free_mem = MPI.Allreduce(p_free_mem, min, comms_ctx.mpicomm)
do_gc = min_p_free_mem < 0.2
@info "GC check" "free mem (MB)" = free_mem / 2^20 "total mem (MB)" = total_mem / 2^20 "Minimum free memory (%)" = min_p_free_mem * 100 "Calling GC" = do_gc
if do_gc
GC.gc()
end
return nothing
end

0 comments on commit 4ae134e

Please sign in to comment.