Skip to content

Commit

Permalink
Add copy to compute! functions for radiation
Browse files Browse the repository at this point in the history
The radiation diagnostics were returning references to objects, leading
to conflicts when multiple diagnostics were added at the same time.

This commit ensures that all the radiation diagnostics return a new
object.

CliMA/ClimaDiagnostics.jl#88 fixes the problem
direclty in ClimaDiagnostics, but I think it is still worth fixing in
atmos, too, so that we can be compatible with previous versions of
ClimaDiagnostics and more uniform with what we are doing with all the
other diagnostics
  • Loading branch information
Sbozzolo committed Oct 7, 2024
1 parent 632de31 commit 569b531
Showing 1 changed file with 119 additions and 81 deletions.
200 changes: 119 additions & 81 deletions src/diagnostics/radiation_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ function compute_rsd!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -56,12 +58,14 @@ function compute_rsdt!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -99,12 +103,14 @@ function compute_rsds!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -142,9 +148,11 @@ function compute_rsu!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -180,12 +188,14 @@ function compute_rsut!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -223,12 +233,14 @@ function compute_rsus!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -266,9 +278,11 @@ function compute_rld!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -303,12 +317,14 @@ function compute_rlds!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -346,9 +362,11 @@ function compute_rlu!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -384,12 +402,14 @@ function compute_rlut!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -427,12 +447,14 @@ function compute_rlus!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -507,12 +529,14 @@ function compute_rsdscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -550,9 +574,11 @@ function compute_rsucs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -588,12 +614,14 @@ function compute_rsutcs!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -631,12 +659,14 @@ function compute_rsuscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -675,9 +705,11 @@ function compute_rldcs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -712,12 +744,14 @@ function compute_rldscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -755,9 +789,11 @@ function compute_rlucs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -793,12 +829,14 @@ function compute_rlutcs!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down

0 comments on commit 569b531

Please sign in to comment.