Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to fix FieldVector inference #2079

Merged
merged 3 commits into from
Sep 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion perf/flame.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ allocs_limit["flame_perf_target"] = 4384
allocs_limit["flame_perf_target_tracers"] = 204016
allocs_limit["flame_perf_target_edmfx"] = 304064
allocs_limit["flame_perf_target_diagnostic_edmfx"] = 685456
allocs_limit["flame_perf_target_edmf"] = 9015243600
allocs_limit["flame_perf_target_edmf"] = 12459299664
allocs_limit["flame_perf_target_threaded"] = 6175664
allocs_limit["flame_perf_target_callbacks"] = 49850536
allocs_limit["flame_perf_gw"] = 4985829472
Expand Down
4 changes: 2 additions & 2 deletions src/TurbulenceConvection_deprecated/dycore_api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ struct CentField <: FieldLocation end
struct FaceField <: FieldLocation end
struct SingleValuePerColumn <: FieldLocation end

field_loc(::CentField) = :cent
field_loc(::FaceField) = :face
field_loc(::CentField) = :c
field_loc(::FaceField) = :f
field_loc(::SingleValuePerColumn) = :svpc

#=
Expand Down
41 changes: 18 additions & 23 deletions src/TurbulenceConvection_deprecated/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,27 +206,28 @@ struct State{P, A, T, CACHE, C, SC}
surface_conditions::SC
end

Grid(state::State) = Grid(axes(state.prog.cent))
Grid(state::State) = Grid(axes(state.prog.c))

float_type(state::State) = eltype(state.prog)
# float_type(field::CC.Fields.Field) = CC.Spaces.undertype(axes(field))
float_type(field::CC.Fields.Field) = eltype(parent(field))

import ClimaCore.Fields as Fields
import ClimaCore.Spaces as Spaces


Base.@propagate_inbounds function field_vector_column(
fv::Fields.FieldVector{T},
colidx::Fields.ColumnIndex,
) where {T}
values = map(x -> x[colidx], Fields._values(fv))
return Fields.FieldVector{T, typeof(values)}(values)
end

function tc_column_state(prog, p, tendencies, colidx, t)
prog_cent_column = CC.column(prog.c, colidx)
prog_face_column = CC.column(prog.f, colidx)
aux_cent_column = CC.column(p.edmf_cache.aux.cent, colidx)
aux_face_column = CC.column(p.edmf_cache.aux.face, colidx)
tends_cent_column = CC.column(tendencies.c, colidx)
tends_face_column = CC.column(tendencies.f, colidx)
prog_column =
CC.Fields.FieldVector(cent = prog_cent_column, face = prog_face_column)
aux_column =
CC.Fields.FieldVector(cent = aux_cent_column, face = aux_face_column)
tends_column = CC.Fields.FieldVector(
cent = tends_cent_column,
face = tends_face_column,
)
prog_column = field_vector_column(prog, colidx)
aux_column = field_vector_column(p.edmf_cache.aux, colidx)
tends_column = field_vector_column(tendencies, colidx)
surface_conditions = CC.column(p.sfc_conditions, colidx)[]
return State(
prog_column,
Expand All @@ -239,14 +240,8 @@ function tc_column_state(prog, p, tendencies, colidx, t)
end

function tc_column_state(prog, p, tendencies::Nothing, colidx, t)
prog_cent_column = CC.column(prog.c, colidx)
prog_face_column = CC.column(prog.f, colidx)
aux_cent_column = CC.column(p.edmf_cache.aux.cent, colidx)
aux_face_column = CC.column(p.edmf_cache.aux.face, colidx)
prog_column =
CC.Fields.FieldVector(cent = prog_cent_column, face = prog_face_column)
aux_column =
CC.Fields.FieldVector(cent = aux_cent_column, face = aux_face_column)
prog_column = field_vector_column(prog, colidx)
aux_column = field_vector_column(p.edmf_cache.aux, colidx)
tends_column = nothing
surface_conditions = CC.column(p.sfc_conditions, colidx)[]
return State(
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ function compute_diagnostics(integrator)
) .+ ᶜa⁺ .* cloud_fraction.(ᶜts⁺, ᶜa⁺),
)
elseif p.atmos.turbconv_model isa TC.EDMFModel
tc_cent(p) = p.edmf_cache.aux.cent.turbconv
tc_face(p) = p.edmf_cache.aux.face.turbconv
tc_cent(p) = p.edmf_cache.aux.c.turbconv
tc_face(p) = p.edmf_cache.aux.f.turbconv
turbulence_convection_diagnostic = (;
bulk_up_area = tc_cent(p).bulk.area,
bulk_up_h_tot = tc_cent(p).bulk.h_tot,
Expand Down
21 changes: 8 additions & 13 deletions src/dycore_equations_deprecated/sgs_flux_tendencies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function turbconv_aux(atmos, edmf, Y, ::Type{FT}) where {FT}

aux_cent_fields = cent_aux_vars.(FT, ᶜlocal_geometry, atmos, edmf)
aux_face_fields = face_aux_vars.(FT, ᶠlocal_geometry, atmos, edmf)
aux = Fields.FieldVector(cent = aux_cent_fields, face = aux_face_fields)
aux = Fields.FieldVector(; c = aux_cent_fields, f = aux_face_fields)
return aux
end

Expand Down Expand Up @@ -174,8 +174,8 @@ function implicit_sgs_flux_tendency!(Yₜ, Y, p, t, colidx, ::TC.EDMFModel)

grid = TC.Grid(state)
if test_consistency
parent(state.aux.face) .= NaN
parent(state.aux.cent) .= NaN
parent(state.aux.f) .= NaN
parent(state.aux.c) .= NaN
end

assign_thermo_aux!(state, edmf.moisture_model, thermo_params)
Expand Down Expand Up @@ -213,8 +213,8 @@ function explicit_sgs_flux_tendency!(Yₜ, Y, p, t, colidx, ::TC.EDMFModel)

grid = TC.Grid(state)
if test_consistency
parent(state.aux.face) .= NaN
parent(state.aux.cent) .= NaN
parent(state.aux.f) .= NaN
parent(state.aux.c) .= NaN
end

assign_thermo_aux!(state, edmf.moisture_model, thermo_params)
Expand All @@ -237,14 +237,9 @@ function explicit_sgs_flux_tendency!(Yₜ, Y, p, t, colidx, ::TC.EDMFModel)
# Note: This "filter relaxation tendency" can be scaled down if needed, but
# it must be present in order to prevent Y and Y_filtered from diverging
# during each timestep.
Yₜ_turbconv =
Fields.FieldVector(c = Yₜ.c.turbconv[colidx], f = Yₜ.f.turbconv[colidx])
Y_filtered_turbconv = Fields.FieldVector(
c = Y_filtered.c.turbconv[colidx],
f = Y_filtered.f.turbconv[colidx],
)
Y_turbconv =
Fields.FieldVector(c = Y.c.turbconv[colidx], f = Y.f.turbconv[colidx])
Yₜ_turbconv = TC.field_vector_column(Yₜ, colidx)
Y_filtered_turbconv = TC.field_vector_column(Y_filtered, colidx)
Y_turbconv = TC.field_vector_column(Y, colidx)
Yₜ_turbconv .+= (Y_filtered_turbconv .- Y_turbconv) ./ Δt
return nothing
end
20 changes: 10 additions & 10 deletions src/parameterized_tendencies/microphysics/precipitation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ function compute_precipitation_cache!(
)
(; ᶜS_ρq_tot) = p
qt_tendency_precip_formation_en =
p.edmf_cache.aux.cent.turbconv.en.qt_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.en.qt_tendency_precip_formation[colidx]
qt_tendency_precip_formation_bulk =
p.edmf_cache.aux.cent.turbconv.bulk.qt_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.bulk.qt_tendency_precip_formation[colidx]

@. ᶜS_ρq_tot[colidx] =
Y.c.ρ[colidx] *
Expand Down Expand Up @@ -196,21 +196,21 @@ function compute_precipitation_cache!(

# Sources of precipitation from EDMF SGS sub-domains
e_tot_tendency_precip_formation_en =
p.edmf_cache.aux.cent.turbconv.en.e_tot_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.en.e_tot_tendency_precip_formation[colidx]
e_tot_tendency_precip_formation_bulk =
p.edmf_cache.aux.cent.turbconv.bulk.e_tot_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.bulk.e_tot_tendency_precip_formation[colidx]
qt_tendency_precip_formation_en =
p.edmf_cache.aux.cent.turbconv.en.qt_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.en.qt_tendency_precip_formation[colidx]
qt_tendency_precip_formation_bulk =
p.edmf_cache.aux.cent.turbconv.bulk.qt_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.bulk.qt_tendency_precip_formation[colidx]
qr_tendency_precip_formation_en =
p.edmf_cache.aux.cent.turbconv.en.qr_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.en.qr_tendency_precip_formation[colidx]
qr_tendency_precip_formation_bulk =
p.edmf_cache.aux.cent.turbconv.bulk.qr_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.bulk.qr_tendency_precip_formation[colidx]
qs_tendency_precip_formation_en =
p.edmf_cache.aux.cent.turbconv.en.qs_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.en.qs_tendency_precip_formation[colidx]
qs_tendency_precip_formation_bulk =
p.edmf_cache.aux.cent.turbconv.bulk.qs_tendency_precip_formation[colidx]
p.edmf_cache.aux.c.turbconv.bulk.qs_tendency_precip_formation[colidx]

thermo_params = CAP.thermodynamics_params(params)
cm_params = CAP.microphysics_params(params)
Expand Down