Skip to content

Commit

Permalink
Add copy to compute! functions for radiation
Browse files Browse the repository at this point in the history
The radiation diagnostics were returning references to objects, leading
to conflicts when multiple diagnostics were added at the same time.

This commit ensures that all the radiation diagnostics return a new
object.

CliMA/ClimaDiagnostics.jl#88 fixes the problem
direclty in ClimaDiagnostics, but I think it is still worth fixing in
atmos, too, so that we can be compatible with previous versions of
ClimaDiagnostics and more uniform with what we are doing with all the
other diagnostics
  • Loading branch information
Sbozzolo committed Oct 7, 2024
1 parent 9ee1c74 commit ac3e831
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 81 deletions.
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ ClimaAtmos.jl Release Notes
Main
-------

### ![][badge-🐛bugfix] Bug fixes

- Fixed radiation diagnostics conflicting with each other. Prior to this change,
adding multiple diagnostics associated to the same variable would lead to
incorrect results when the more diagnostics were output at the same time. PR
[3365](https://github.com/CliMA/ClimaAtmos.jl/pull/3365)

v0.27.6
-------
Expand Down
200 changes: 119 additions & 81 deletions src/diagnostics/radiation_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ function compute_rsd!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -56,12 +58,14 @@ function compute_rsdt!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -99,12 +103,14 @@ function compute_rsds!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -142,9 +148,11 @@ function compute_rsu!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -180,12 +188,14 @@ function compute_rsut!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -223,12 +233,14 @@ function compute_rsus!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_sw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -266,9 +278,11 @@ function compute_rld!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -303,12 +317,14 @@ function compute_rlds!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -346,9 +362,11 @@ function compute_rlu!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -384,12 +402,14 @@ function compute_rlut!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -427,12 +447,14 @@ function compute_rlus!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_lw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -507,12 +529,14 @@ function compute_rsdscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -550,9 +574,11 @@ function compute_rsucs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -588,12 +614,14 @@ function compute_rsutcs!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -631,12 +659,14 @@ function compute_rsuscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_sw_flux_up,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -675,9 +705,11 @@ function compute_rldcs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -712,12 +744,14 @@ function compute_rldscs!(
radiation_mode::T,
) where {T <: RRTMGPI.AbstractRRTMGPMode}
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_dn,
axes(state.f),
),
half,
),
half,
)
else
out .= Fields.level(
Expand Down Expand Up @@ -755,9 +789,11 @@ function compute_rlucs!(
radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics,
)
if isnothing(out)
return Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
return copy(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
),
)
else
out .= Fields.array2field(
Expand Down Expand Up @@ -793,12 +829,14 @@ function compute_rlutcs!(
) where {T <: RRTMGPI.AbstractRRTMGPMode}
nlevels = Spaces.nlevels(axes(state.c))
if isnothing(out)
return Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
return copy(
Fields.level(
Fields.array2field(
cache.radiation.rrtmgp_model.face_clear_lw_flux_up,
axes(state.f),
),
nlevels + half,
),
nlevels + half,
)
else
out .= Fields.level(
Expand Down

0 comments on commit ac3e831

Please sign in to comment.