From 9f6a4eca180d05fe8254a33d591d43ded82bfb9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E8=90=A7=E6=B6=AF?= Date: Thu, 8 Aug 2024 20:04:11 -0700 Subject: [PATCH] optimize efficiency --- artifacts/Artifacts.toml | 2 + .../non_orographic_gravity_wave.jl | 432 ++++++++---------- 2 files changed, 198 insertions(+), 236 deletions(-) create mode 100644 artifacts/Artifacts.toml diff --git a/artifacts/Artifacts.toml b/artifacts/Artifacts.toml new file mode 100644 index 00000000000..eb5e1bce1c7 --- /dev/null +++ b/artifacts/Artifacts.toml @@ -0,0 +1,2 @@ +[topo-elev-info] +git-tree-sha1 = "845995bb777cf5a3920541585d62c087f62e5cb5" diff --git a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl index 8b4df7d29de..9ceacec5fcd 100644 --- a/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl +++ b/src/parameterized_tendencies/gravity_wave_drag/non_orographic_gravity_wave.jl @@ -165,22 +165,9 @@ function non_orographic_gravity_wave_tendency!( uforcing, vforcing, ᶜlevel, - gw_cmax, - gw_dc, gw_ncval, ) = p.non_orographic_gravity_wave (; model_config) = p.atmos - (; - gw_source_ampl, - gw_Bw, - gw_Bn, - gw_c, - gw_cw, - gw_cn, - gw_flag, - gw_c0, - gw_nk, - ) = p.non_orographic_gravity_wave if model_config isa SingleColumnModel (; gw_source_height, source_ρ_z_u_v_level) = @@ -279,11 +266,6 @@ function non_orographic_gravity_wave_tendency!( ᶜu = Geometry.UVVector.(Y.c.uₕ).components.data.:1 ᶜv = Geometry.UVVector.(Y.c.uₕ).components.data.:2 - ᶜρ_p1 = p.scratch.ᶜtemp_scalar - ᶜz_p1 = p.scratch.ᶜtemp_scalar_2 - ᶜu_p1 = p.scratch.ᶜtemp_scalar_3 - ᶜv_p1 = p.scratch.ᶜtemp_scalar_4 - ᶜbf_p1 = p.scratch.ᶜtemp_scalar_5 uforcing .= 0 vforcing .= 0 @@ -297,30 +279,15 @@ function non_orographic_gravity_wave_tendency!( ᶜlevel, source_level, damp_level, - gw_source_ampl, ᶜρ_source, ᶜu_source, ᶜv_source, uforcing, vforcing, - gw_Bw, - gw_Bn, - gw_cw, - gw_cn, - gw_flag, gw_ncval, - gw_c, - gw_c0, - gw_nk, - gw_cmax, - gw_dc, u_waveforcing, v_waveforcing, - ᶜρ_p1, - ᶜu_p1, - ᶜv_p1, - ᶜbf_p1, - ᶜz_p1, + p, ) @. Yₜ.c.uₕ += @@ -337,31 +304,32 @@ function non_orographic_gravity_wave_forcing( ᶜlevel, source_level, damp_level, - gw_source_ampl, ᶜρ_source, ᶜu_source, ᶜv_source, uforcing, vforcing, - gw_Bw, - gw_Bn, - gw_cw, - gw_cn, - gw_flag, gw_ncval::Val{nc}, - gw_c, - c0, - nk, - cmax, - dc, u_waveforcing, v_waveforcing, - ᶜρ_p1, - ᶜu_p1, - ᶜv_p1, - ᶜbf_p1, - ᶜz_p1, + p, ) where {nc} + (; + gw_source_ampl, + gw_Bw, + gw_Bn, + gw_c, + gw_cw, + gw_cn, + gw_flag, + gw_c0, + gw_nk, + ) = p.non_orographic_gravity_wave + ᶜρ_p1 = p.scratch.ᶜtemp_scalar + ᶜz_p1 = p.scratch.ᶜtemp_scalar_2 + ᶜu_p1 = p.scratch.ᶜtemp_scalar_3 + ᶜv_p1 = p.scratch.ᶜtemp_scalar_4 + ᶜbf_p1 = p.scratch.ᶜtemp_scalar_5 nci = get_nc(gw_ncval) FT = eltype(ᶜρ) ρ_end = Fields.level(ᶜρ, Spaces.nlevels(axes(ᶜρ))) @@ -414,7 +382,7 @@ function non_orographic_gravity_wave_forcing( level_end = Spaces.nlevels(axes(ᶜρ)) B1 = ntuple(i -> 0.0, Val(nc)) - for ink in 1:nk + for ink in 1:gw_nk input_u = Base.Broadcast.broadcasted( tuple, @@ -462,7 +430,7 @@ function non_orographic_gravity_wave_forcing( Operators.column_accumulate!( u_waveforcing, input_u; - init = (Float32(0.0), mask, 0.0, B1), + init = (FT(0.0), mask, 0.0, B1), transform = first, ) do (wave_forcing, mask, Bsum, B0), ( @@ -484,108 +452,103 @@ function non_orographic_gravity_wave_forcing( level, source_ampl, ) - - FT1 = typeof(u_kp1) - kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) - k2 = kwv * kwv - - # loop over all wave lengths - - - # here ᶜu has one additional level above model top - fac = FT1(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 - Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height - alp2 = FT1(0.25) / (Hb * Hb) - ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) - - fm = 0.0 - if level == 1 - #mask = ntuple(i -> true, Val(nc)) - mask .= 1 - Bsum = 0.0 - B0 = ntuple( - n -> - sign(((n - 1) * dc - cmax - u_kp1)) * ( - Bw * exp( - -log(2.0) * - ( + if level >= (source_level - 1) + FT1 = typeof(u_kp1) + kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) + k2 = kwv * kwv + fac = FT1(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 + Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height + alp2 = FT1(0.25) / (Hb * Hb) + ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) + fm = 0.0 + if level == (source_level - 1) + mask .= 1 + Bsum = 0.0 + B0 = ntuple( + n -> + sign(c[n] - u_source) * ( + Bw * exp( + -log(2.0) * ( - ((n - 1) * dc - cmax) * flag + - ((n - 1) * dc - cmax - u_source) * - (1 - flag) - c0 - ) / cw - )^2, - ) + - Bn * exp( - -log(2.0) * - ( + ( + c[n] * flag + + (c[n] - u_source) * (1 - flag) - + gw_c0 + ) / cw + )^2, + ) + + Bn * exp( + -log(2.0) * ( - ((n - 1) * dc - cmax) * flag + - ((n - 1) * dc - cmax - u_source) * - (1 - flag) - c0 - ) / cn - )^2, - ) - ), - Val(nc), - ) - Bsum = sum(abs.(B0)) - end - for n in 1:nci - # check only those waves which are still propagating, i.e., mask = 1.0 - if (mask[n]) == 1 - c_hat = c[n] - u_kp1 # c0mu - # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. - if c_hat == 0.0 - mask[n] = 0 - else - c_hat0 = c[n] - u_source - # define the criterion which determines if wave is reflected at this level (test). - test = abs(c_hat) * kwv - ω_r - if test >= 0.0 - # wave has undergone total internal reflection. remove it from the propagating set. + ( + c[n] * flag + + (c[n] - u_source) * (1 - flag) - + gw_c0 + ) / cn + )^2, + ) + ), + Val(nc), + ) + Bsum = sum(abs.(B0)) + end + for n in 1:nci + # check only those waves which are still propagating, i.e., mask = 1.0 + if (mask[n]) == 1 + c_hat = c[n] - u_kp1 # c0mu + # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. + if c_hat == 0.0 mask[n] = 0 else - if level == level_end - # this is added in MiMA implementation: - # all momentum flux that escapes across the model top - # is deposited to the extra level being added so that - # momentum flux is conserved + c_hat0 = c[n] - u_source + # define the criterion which determines if wave is reflected at this level (test). + test = abs(c_hat) * kwv - ω_r + if test >= 0.0 + # wave has undergone total internal reflection. remove it from the propagating set. mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end else - # if wave is not reflected at this level, determine if it is - # breaking at this level (Foc >= 0), or if wave speed relative to - # windspeed has changed sign from its value at the source level - # (c_hat0[n] * c_hat <= 0). if it is above the source level and is - # breaking, then add its momentum flux to the accumulated sum at - # this level. - # set mask=0.0 to remove phase speed band c[n] from the set of active - # waves moving upwards to the next level. - Foc = B0[n] / FT1((c_hat)^3) - fac - if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) + if level == level_end + # this is added in MiMA implementation: + # all momentum flux that escapes across the model top + # is deposited to the extra level being added so that + # momentum flux is conserved mask[n] = 0 if level >= source_level fm = fm + B0[n] end + else + # if wave is not reflected at this level, determine if it is + # breaking at this level (Foc >= 0), or if wave speed relative to + # windspeed has changed sign from its value at the source level + # (c_hat0[n] * c_hat <= 0). if it is above the source level and is + # breaking, then add its momentum flux to the accumulated sum at + # this level. + # set mask=0.0 to remove phase speed band c[n] from the set of active + # waves moving upwards to the next level. + Foc = B0[n] / FT1((c_hat)^3) - fac + if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) + mask[n] = 0 + if level >= source_level + fm = fm + B0[n] + end + end end - end - end # (test >= 0.0) - end #(c_hat == 0.0) - end # mask = 0 - - end # nc: phase speed loop - - # compute the gravity wave momentum flux forcing - # obtained across the entire wave spectrum at this level. - eps = calc_intermitency(ρ_source, source_ampl, nk, FT1(Bsum)) - if level >= source_level - rbh = sqrt(ρ_k * ρ_kp1) - wave_forcing = (ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k) - else - wave_forcing = FT1(0.0) + end # (test >= 0.0) + end #(c_hat == 0.0) + end # mask = 0 + + end # nc: phase speed loop + + # compute the gravity wave momentum flux forcing + # obtained across the entire wave spectrum at this level. + eps = calc_intermitency(ρ_source, source_ampl, gw_nk, FT1(Bsum)) + if level >= source_level + rbh = sqrt(ρ_k * ρ_kp1) + wave_forcing = + (ρ_source / rbh) * FT1(fm) * eps / (z_kp1 - z_k) + else + wave_forcing = FT1(0.0) + end end return (wave_forcing, mask, Bsum, B0) @@ -596,7 +559,7 @@ function non_orographic_gravity_wave_forcing( Operators.column_accumulate!( v_waveforcing, input_v; - init = (Float32(0.0), mask, 0.0, B1), + init = (FT(0.0), mask, 0.0, B1), transform = first, ) do (wave_forcing, mask, Bsum, B0), ( @@ -619,107 +582,104 @@ function non_orographic_gravity_wave_forcing( source_ampl, ) - FT2 = typeof(u_kp1) - kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) - k2 = kwv * kwv - - # loop over all wave lengths - - - # here ᶜu has one additional level above model top - fac = FT2(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 - Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height - alp2 = FT2(0.25) / (Hb * Hb) - ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) - - fm = 0.0 - if level == 1 - #mask = ntuple(i -> true, Val(nc)) - mask .= 1 - Bsum = 0.0 - B0 = ntuple( - n -> - sign(((n - 1) * dc - cmax - u_kp1)) * ( - Bw * exp( - -log(2.0) * - ( + if level >= (source_level - 1) + FT2 = typeof(u_kp1) + kwv = 2.0 * π / ((30.0 * (10.0^ink)) * 1.e3) + k2 = kwv * kwv + fac = FT2(0.5) * (ρ_kp1 / ρ_source) * kwv / bf_kp1 + Hb = (z_kp1 - z_k) / log(ρ_k / ρ_kp1) # density scale height + alp2 = FT2(0.25) / (Hb * Hb) + ω_r = sqrt((bf_kp1 * bf_kp1 * k2) / (k2 + alp2)) # omc: (critical frequency that marks total internal reflection) + + fm = 0.0 + if level == (source_level - 1) + mask .= 1 + Bsum = 0.0 + B0 = ntuple( + n -> + sign((c[n] - u_source)) * ( + Bw * exp( + -log(2.0) * ( - ((n - 1) * dc - cmax) * flag + - ((n - 1) * dc - cmax - u_source) * - (1 - flag) - c0 - ) / cw - )^2, - ) + - Bn * exp( - -log(2.0) * - ( + ( + c[n] * flag + + (c[n] - u_source) * (1 - flag) - + gw_c0 + ) / cw + )^2, + ) + + Bn * exp( + -log(2.0) * ( - ((n - 1) * dc - cmax) * flag + - ((n - 1) * dc - cmax - u_source) * - (1 - flag) - c0 - ) / cn - )^2, - ) - ), - Val(nc), - ) - Bsum = sum(abs.(B0)) - end - for n in 1:nci - # check only those waves which are still propagating, i.e., mask = 1.0 - if (mask[n]) == 1 - c_hat = c[n] - u_kp1 # c0mu - # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. - if c_hat == 0.0 - mask[n] = 0 - else - c_hat0 = c[n] - u_source - # define the criterion which determines if wave is reflected at this level (test). - test = abs(c_hat) * kwv - ω_r - if test >= 0.0 - # wave has undergone total internal reflection. remove it from the propagating set. + ( + c[n] * flag + + (c[n] - u_source) * (1 - flag) - + gw_c0 + ) / cn + )^2, + ) + ), + Val(nc), + ) + Bsum = sum(abs.(B0)) + end + for n in 1:nci + # check only those waves which are still propagating, i.e., mask = 1.0 + if (mask[n]) == 1 + c_hat = c[n] - u_kp1 # c0mu + # f phase speed matches the wind speed, remove c(n) from the set of propagating waves. + if c_hat == 0.0 mask[n] = 0 else - if level == level_end - # this is added in MiMA implementation: - # all momentum flux that escapes across the model top - # is deposited to the extra level being added so that - # momentum flux is conserved + c_hat0 = c[n] - u_source + # define the criterion which determines if wave is reflected at this level (test). + test = abs(c_hat) * kwv - ω_r + if test >= 0.0 + # wave has undergone total internal reflection. remove it from the propagating set. mask[n] = 0 - if level >= source_level - fm = fm + B0[n] - end else - # if wave is not reflected at this level, determine if it is - # breaking at this level (Foc >= 0), or if wave speed relative to - # windspeed has changed sign from its value at the source level - # (c_hat0[n] * c_hat <= 0). if it is above the source level and is - # breaking, then add its momentum flux to the accumulated sum at - # this level. - # set mask=0.0 to remove phase speed band c[n] from the set of active - # waves moving upwards to the next level. - Foc = B0[n] / FT2((c_hat)^3) - fac - if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) + if level == level_end + # this is added in MiMA implementation: + # all momentum flux that escapes across the model top + # is deposited to the extra level being added so that + # momentum flux is conserved mask[n] = 0 if level >= source_level fm = fm + B0[n] end + else + # if wave is not reflected at this level, determine if it is + # breaking at this level (Foc >= 0), or if wave speed relative to + # windspeed has changed sign from its value at the source level + # (c_hat0[n] * c_hat <= 0). if it is above the source level and is + # breaking, then add its momentum flux to the accumulated sum at + # this level. + # set mask=0.0 to remove phase speed band c[n] from the set of active + # waves moving upwards to the next level. + Foc = B0[n] / FT2((c_hat)^3) - fac + if Foc >= 0.0 || (c_hat0 * c_hat <= 0.0) + mask[n] = 0 + if level >= source_level + fm = fm + B0[n] + end + end end - end - end # (test >= 0.0) - end #(c_hat == 0.0) - end # mask = 0 - - end # nc: phase speed loop - - # compute the gravity wave momentum flux forcing - # obtained across the entire wave spectrum at this level. - eps = calc_intermitency(ρ_source, source_ampl, nk, FT2(Bsum)) - if level >= source_level - rbh = sqrt(ρ_k * ρ_kp1) - wave_forcing = (ρ_source / rbh) * FT2(fm) * eps / (z_kp1 - z_k) - else - wave_forcing = FT2(0.0) + end # (test >= 0.0) + end #(c_hat == 0.0) + end # mask = 0 + + end # nc: phase speed loop + + # compute the gravity wave momentum flux forcing + # obtained across the entire wave spectrum at this level. + eps = calc_intermitency(ρ_source, source_ampl, gw_nk, FT2(Bsum)) + if level >= source_level + rbh = sqrt(ρ_k * ρ_kp1) + wave_forcing = + (ρ_source / rbh) * FT2(fm) * eps / (z_kp1 - z_k) + else + wave_forcing = FT2(0.0) + end end return (wave_forcing, mask, Bsum, B0) @@ -793,8 +753,8 @@ end function gw_average!(wave_forcing) L1 = Operators.LeftBiasedC2F(; bottom = Operators.SetValue(0)) L2 = Operators.LeftBiasedF2C(;) - wave_forcing_m1 = L2.(L1.(wave_forcing)) - @. wave_forcing = 0.5 * (wave_forcing + wave_forcing_m1) + #wave_forcing_m1 = L2.(L1.(wave_forcing)) + @. wave_forcing = 0.5 * (wave_forcing .+ L2.(L1.(wave_forcing))) end