From 569b5314857c389cd018bfd88ee9d383f6cc9e6e Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Mon, 7 Oct 2024 14:43:12 -0700 Subject: [PATCH] Add copy to compute! functions for radiation The radiation diagnostics were returning references to objects, leading to conflicts when multiple diagnostics were added at the same time. This commit ensures that all the radiation diagnostics return a new object. https://github.com/CliMA/ClimaDiagnostics.jl/pull/88 fixes the problem direclty in ClimaDiagnostics, but I think it is still worth fixing in atmos, too, so that we can be compatible with previous versions of ClimaDiagnostics and more uniform with what we are doing with all the other diagnostics --- src/diagnostics/radiation_diagnostics.jl | 200 ++++++++++++++--------- 1 file changed, 119 insertions(+), 81 deletions(-) diff --git a/src/diagnostics/radiation_diagnostics.jl b/src/diagnostics/radiation_diagnostics.jl index abbd0b52f60..28eae807957 100644 --- a/src/diagnostics/radiation_diagnostics.jl +++ b/src/diagnostics/radiation_diagnostics.jl @@ -18,9 +18,11 @@ function compute_rsd!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_dn, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_dn, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -56,12 +58,14 @@ function compute_rsdt!( ) where {T <: RRTMGPI.AbstractRRTMGPMode} nlevels = Spaces.nlevels(axes(state.c)) if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_dn, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_dn, + axes(state.f), + ), + nlevels + half, ), - nlevels + half, ) else out .= Fields.level( @@ -99,12 +103,14 @@ function compute_rsds!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_dn, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_dn, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -142,9 +148,11 @@ function compute_rsu!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_up, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_up, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -180,12 +188,14 @@ function compute_rsut!( ) where {T <: RRTMGPI.AbstractRRTMGPMode} nlevels = Spaces.nlevels(axes(state.c)) if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_up, + axes(state.f), + ), + nlevels + half, ), - nlevels + half, ) else out .= Fields.level( @@ -223,12 +233,14 @@ function compute_rsus!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_sw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_sw_flux_up, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -266,9 +278,11 @@ function compute_rld!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_lw_flux_dn, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_lw_flux_dn, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -303,12 +317,14 @@ function compute_rlds!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_lw_flux_dn, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_lw_flux_dn, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -346,9 +362,11 @@ function compute_rlu!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_lw_flux_up, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_lw_flux_up, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -384,12 +402,14 @@ function compute_rlut!( ) where {T <: RRTMGPI.AbstractRRTMGPMode} nlevels = Spaces.nlevels(axes(state.c)) if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_lw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_lw_flux_up, + axes(state.f), + ), + nlevels + half, ), - nlevels + half, ) else out .= Fields.level( @@ -427,12 +447,14 @@ function compute_rlus!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_lw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_lw_flux_up, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -507,12 +529,14 @@ function compute_rsdscs!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_sw_flux_dn, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_sw_flux_dn, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -550,9 +574,11 @@ function compute_rsucs!( radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics, ) if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_sw_flux_up, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_sw_flux_up, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -588,12 +614,14 @@ function compute_rsutcs!( ) where {T <: RRTMGPI.AbstractRRTMGPMode} nlevels = Spaces.nlevels(axes(state.c)) if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_sw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_sw_flux_up, + axes(state.f), + ), + nlevels + half, ), - nlevels + half, ) else out .= Fields.level( @@ -631,12 +659,14 @@ function compute_rsuscs!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_sw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_sw_flux_up, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -675,9 +705,11 @@ function compute_rldcs!( radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics, ) if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_lw_flux_dn, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_lw_flux_dn, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -712,12 +744,14 @@ function compute_rldscs!( radiation_mode::T, ) where {T <: RRTMGPI.AbstractRRTMGPMode} if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_lw_flux_dn, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_lw_flux_dn, + axes(state.f), + ), + half, ), - half, ) else out .= Fields.level( @@ -755,9 +789,11 @@ function compute_rlucs!( radiation_mode::RRTMGPI.AllSkyRadiationWithClearSkyDiagnostics, ) if isnothing(out) - return Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_lw_flux_up, - axes(state.f), + return copy( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_lw_flux_up, + axes(state.f), + ), ) else out .= Fields.array2field( @@ -793,12 +829,14 @@ function compute_rlutcs!( ) where {T <: RRTMGPI.AbstractRRTMGPMode} nlevels = Spaces.nlevels(axes(state.c)) if isnothing(out) - return Fields.level( - Fields.array2field( - cache.radiation.rrtmgp_model.face_clear_lw_flux_up, - axes(state.f), + return copy( + Fields.level( + Fields.array2field( + cache.radiation.rrtmgp_model.face_clear_lw_flux_up, + axes(state.f), + ), + nlevels + half, ), - nlevels + half, ) else out .= Fields.level(