Skip to content

Commit

Permalink
Cache: hyperdiffiusion
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Oct 27, 2023
1 parent 71da309 commit e82cf76
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/cache/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ function build_cache(Y, atmos, params, surface_setup, simulation)
env_thermo_quad = SGSQuadrature(FT),
precomputed_quantities(Y, atmos)...,
scratch = temporary_quantities(Y, atmos),
hyperdiffusion_cache(Y, atmos, do_dss)...,
numerics,
)
set_precomputed_quantities!(Y, default_cache, FT(0))
Expand All @@ -85,6 +84,7 @@ function build_cache(Y, atmos, params, surface_setup, simulation)

return merge(
(;
hyperdiff = hyperdiffusion_cache(Y, atmos),
rayleigh_sponge = rayleigh_sponge_cache(Y, atmos),
viscous_sponge = viscous_sponge_cache(Y, atmos),
precipitation = precipitation_cache(Y, atmos),
Expand Down
7 changes: 5 additions & 2 deletions src/cache/temporary_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ using ClimaCore.Utilities: half
# between function calls.
function temporary_quantities(Y, atmos)
center_space, face_space = axes(Y.c), axes(Y.f)

FT = Spaces.undertype(center_space)
n = n_mass_flux_subdomains(atmos.turbconv_model)
return (;
ᶠtemp_scalar = Fields.Field(FT, face_space), # ᶠp, ᶠρK_E
ᶜtemp_scalar = Fields.Field(FT, center_space), # ᶜ1
Expand All @@ -25,7 +25,10 @@ function temporary_quantities(Y, atmos)
ᶜtemp_CT3 = Fields.Field(CT3{FT}, center_space), # ᶜω³, ᶜ∇Φ³
ᶠtemp_CT3 = Fields.Field(CT3{FT}, face_space), # ᶠuₕ³
ᶠtemp_CT12 = Fields.Field(CT12{FT}, face_space), # ᶠω¹²
ᶠtemp_CT12ʲs = Fields.Field(NTuple{n, CT12{FT}}, face_space), # ᶠω¹²ʲs
ᶠtemp_CT12ʲs = Fields.Field(
NTuple{n_mass_flux_subdomains(atmos.turbconv_model), CT12{FT}},
face_space,
), # ᶠω¹²ʲs
ᶠtemp_C123 = Fields.Field(C123{FT}, face_space), # χ₁₂₃
ᶜtemp_UVWxUVW = Fields.Field(
typeof(UVW(FT(0), FT(0), FT(0)) * UVW(FT(0), FT(0), FT(0))'),
Expand Down
48 changes: 30 additions & 18 deletions src/prognostic_equations/hyperdiffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,18 @@ import ClimaCore.Geometry as Geometry
import ClimaCore.Fields as Fields
import ClimaCore.Spaces as Spaces

function hyperdiffusion_cache(Y, atmos, do_dss)
isnothing(atmos.hyperdiff) && return (;)
hyperdiffusion_cache(Y, atmos) =
hyperdiffusion_cache(Y, atmos.hyperdiff, atmos.turbconv_model)

# No hyperdiffiusion
hyperdiffusion_cache(Y, hyperdiff::Nothing, _) = (;)

function hyperdiffusion_cache(Y, hyperdiff::ClimaHyperdiffusion, turbconv_model)
do_dss =
Spaces.horizontal_space(axes(Y.c)).quadrature_style isa
Spaces.Quadratures.GLL
FT = eltype(Y)
n = n_mass_flux_subdomains(atmos.turbconv_model)
n = n_mass_flux_subdomains(turbconv_model)

# Grid scale quantities
ᶜ∇²u = similar(Y.c, C123{FT})
Expand All @@ -21,19 +29,19 @@ function hyperdiffusion_cache(Y, atmos, do_dss)

# Sub-grid scale quantities
ᶜ∇²uʲs =
atmos.turbconv_model isa PrognosticEDMFX ?
similar(Y.c, NTuple{n, C123{FT}}) : (;)
turbconv_model isa PrognosticEDMFX ? similar(Y.c, NTuple{n, C123{FT}}) :
(;)
sgs_quantities =
atmos.turbconv_model isa PrognosticEDMFX ?
turbconv_model isa PrognosticEDMFX ?
(;
ᶜ∇²tke⁰ = similar(Y.c, FT),
ᶜ∇²uₕʲs = similar(Y.c, NTuple{n, C12{FT}}),
ᶜ∇²uᵥʲs = similar(Y.c, NTuple{n, C3{FT}}),
ᶜ∇²h_totʲs = similar(Y.c, NTuple{n, FT}),
ᶜ∇²q_totʲs = similar(Y.c, NTuple{n, FT}),
) :
atmos.turbconv_model isa DiagnosticEDMFX ?
(; ᶜ∇²tke⁰ = similar(Y.c, FT)) : (;)
turbconv_model isa DiagnosticEDMFX ? (; ᶜ∇²tke⁰ = similar(Y.c, FT)) :
(;)
quantities = (; gs_quantities..., sgs_quantities...)
if do_dss
quantities = (;
Expand All @@ -56,15 +64,18 @@ NVTX.@annotate function hyperdiffusion_tendency!(Yₜ, Y, p, t)
diffuse_tke = use_prognostic_tke(turbconv_model)
ᶜJ = Fields.local_geometry_field(Y.c).J
point_type = eltype(Fields.coordinate_field(Y.c))
(; do_dss, ᶜp, ᶜspecific, ᶜ∇²u, ᶜ∇²specific_energy) = p
(; do_dss, ᶜp, ᶜspecific) = p
(; ᶜ∇²u, ᶜ∇²specific_energy) = p.hyperdiff
if turbconv_model isa PrognosticEDMFX
(; ᶜρa⁰, ᶜρʲs, ᶜ∇²tke⁰, ᶜtke⁰, ᶜ∇²uₕʲs, ᶜ∇²uᵥʲs, ᶜ∇²uʲs, ᶜ∇²h_totʲs) = p
(; ᶜρa⁰, ᶜtke⁰) = p
(; ᶜ∇²tke⁰, ᶜ∇²uₕʲs, ᶜ∇²uᵥʲs, ᶜ∇²uʲs, ᶜ∇²h_totʲs) = p.hyperdiff
end
if turbconv_model isa DiagnosticEDMFX
(; ᶜtke⁰, ᶜ∇²tke⁰) = p
(; ᶜtke⁰) = p
(; ᶜ∇²tke⁰) = p.hyperdiff
end
if do_dss
buffer = p.hyperdiffusion_ghost_buffer
buffer = p.hyperdiff.hyperdiffusion_ghost_buffer
end

# Grid scale hyperdiffusion
Expand Down Expand Up @@ -97,16 +108,16 @@ NVTX.@annotate function hyperdiffusion_tendency!(Yₜ, Y, p, t)
Spaces.weighted_dss_ghost!,
)
# DSS on Grid scale quantities
# Need to split the DSS computation here, because our DSS
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
dss_op!(ᶜ∇²u, buffer.ᶜ∇²u)
dss_op!(ᶜ∇²specific_energy, buffer.ᶜ∇²specific_energy)
if diffuse_tke
dss_op!(ᶜ∇²tke⁰, buffer.ᶜ∇²tke⁰)
end
if turbconv_model isa PrognosticEDMFX
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
# Need to split the DSS computation here, because our DSS
# operations do not accept Covariant123Vector types
for j in 1:n
@. ᶜ∇²uₕʲs.:($$j) = C12(ᶜ∇²uʲs.:($$j))
@. ᶜ∇²uᵥʲs.:($$j) = C3(ᶜ∇²uʲs.:($$j))
Expand Down Expand Up @@ -162,12 +173,13 @@ NVTX.@annotate function tracer_hyperdiffusion_tendency!(Yₜ, Y, p, t)
(; κ₄) = hyperdiff
n = n_mass_flux_subdomains(turbconv_model)

(; do_dss, ᶜspecific, ᶜ∇²specific_tracers) = p
(; do_dss, ᶜspecific) = p
(; ᶜ∇²specific_tracers) = p.hyperdiff
if turbconv_model isa PrognosticEDMFX
(; ᶜ∇²q_totʲs) = p
(; ᶜ∇²q_totʲs) = p.hyperdiff
end
if do_dss
buffer = p.hyperdiffusion_ghost_buffer
buffer = p.hyperdiff.hyperdiffusion_ghost_buffer
end

for χ_name in propertynames(ᶜ∇²specific_tracers)
Expand Down

0 comments on commit e82cf76

Please sign in to comment.