From 4a528bcae57a3d430b545f5cc903aaea76e19844 Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Tue, 8 Oct 2024 10:56:23 -0400
Subject: [PATCH] Manually fuse some broadcast expressions in diagnostic edmf

---
Reviewer notes (below the `---` marker, so not part of the commit message):

* Hunk 1 adds a scalar helper, `compute_u³ʲ_u³ʲ`, that takes eleven
  positional arguments and returns one value accumulated from four terms.
  Every term is scaled by `1 / J_halflevel^2`:
    1. `+ J_prev_halflevel^2 * u³ʲ_prev_halflevel^2`
    2. `- J_prev_level^2 * 2 * ∇Φ³ * (ρʲ - ρ) / ρʲ`
    3. `+ J_prev_level^2 * 2 * (entrʲ + turb_entrʲ) * (u³⁰ - u³ʲ)`, written
       in the code as the difference of the two products
    4. `- J_prev_level^2 * 2 * nh_pressure³ʲ`
* Hunk 2 removes four consecutive `@.` statements (one assignment followed
  by `-=`, `+=`, `-=`) that updated `u³ʲ_datau³ʲ_data`, and replaces them
  with a single `@.` broadcast of the helper. The terms, their order and
  their parenthesisation are the same as in the removed statements.
* The removed code spelled squares as `J * J`; its comment says constant
  exponents in broadcasts allocate (ClimaCore.jl issue #1126). The helper
  writes `J^2` inside an ordinary function body instead of inside the
  broadcast expression.
  NOTE(review): presumably this sidesteps that issue -- confirm with an
  allocation check.
* NOTE(review): the line breaks and indentation in this copy of the patch
  were reconstructed from a whitespace-collapsed source. The hunk line
  counts match the headers, but exact leading whitespace and the position
  of blank context lines are assumed. Verify with `git apply --check`
  (or use `--ignore-whitespace`) before relying on it.

 .../diagnostic_edmf_precomputed_quantities.jl | 112 +++++++++---------
 1 file changed, 58 insertions(+), 54 deletions(-)

diff --git a/src/cache/diagnostic_edmf_precomputed_quantities.jl b/src/cache/diagnostic_edmf_precomputed_quantities.jl
index b8e12c23ba..cd03f0e52d 100644
--- a/src/cache/diagnostic_edmf_precomputed_quantities.jl
+++ b/src/cache/diagnostic_edmf_precomputed_quantities.jl
@@ -216,6 +216,51 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_bottom_bc!(
     return nothing
 end
 
+function compute_u³ʲ_u³ʲ(
+    u³ʲ_prev_halflevel,
+    J_prev_halflevel,
+    J_halflevel,
+    J_prev_level,
+    ∇Φ³_data_prev_level,
+    ρʲ_prev_level,
+    ρ_prev_level,
+    entrʲ_prev_level,
+    turb_entrʲ_prev_level,
+    u³⁰_data_prev_halflevel,
+    nh_pressure³ʲ_data_prev_halflevel,
+)
+    u³ʲ_u³ʲ =
+        (1 / (J_halflevel^2)) *
+        (J_prev_halflevel^2 * u³ʲ_prev_halflevel * u³ʲ_prev_halflevel)
+
+    u³ʲ_u³ʲ -=
+        (1 / (J_halflevel^2)) * (
+            J_prev_level^2 *
+            2 *
+            (
+                ∇Φ³_data_prev_level * (ρʲ_prev_level - ρ_prev_level) /
+                ρʲ_prev_level
+            )
+        )
+
+    u³ʲ_u³ʲ +=
+        (1 / (J_halflevel^2)) * (
+            J_prev_level^2 *
+            2 *
+            (
+                (entrʲ_prev_level + turb_entrʲ_prev_level) *
+                u³⁰_data_prev_halflevel -
+                (entrʲ_prev_level + turb_entrʲ_prev_level) * u³ʲ_prev_halflevel
+            )
+        )
+
+    u³ʲ_u³ʲ -=
+        (1 / (J_halflevel^2)) *
+        (J_prev_level^2 * 2 * nh_pressure³ʲ_data_prev_halflevel)
+    return u³ʲ_u³ʲ
+end
+
+
 NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
     Y,
     p,
@@ -490,60 +535,19 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
             end
 
             u³ʲ_datau³ʲ_data = p.scratch.temp_data_level
-            # Using constant exponents in broadcasts allocate, so we use
-            # local_geometry_halflevel.J * local_geometry_halflevel.J instead.
-            # See ClimaCore.jl issue #1126.
-            @. u³ʲ_datau³ʲ_data =
-                (
-                    1 /
-                    (local_geometry_halflevel.J * local_geometry_halflevel.J)
-                ) * (
-                    local_geometry_prev_halflevel.J *
-                    local_geometry_prev_halflevel.J *
-                    u³ʲ_data_prev_halflevel *
-                    u³ʲ_data_prev_halflevel
-                )
-
-            @. u³ʲ_datau³ʲ_data -=
-                (
-                    1 /
-                    (local_geometry_halflevel.J * local_geometry_halflevel.J)
-                ) * (
-                    local_geometry_prev_level.J *
-                    local_geometry_prev_level.J *
-                    2 *
-                    (
-                        ∇Φ³_data_prev_level * (ρʲ_prev_level - ρ_prev_level) /
-                        ρʲ_prev_level
-                    )
-                )
-
-            @. u³ʲ_datau³ʲ_data +=
-                (
-                    1 /
-                    (local_geometry_halflevel.J * local_geometry_halflevel.J)
-                ) * (
-                    local_geometry_prev_level.J *
-                    local_geometry_prev_level.J *
-                    2 *
-                    (
-                        (entrʲ_prev_level + turb_entrʲ_prev_level) *
-                        u³⁰_data_prev_halflevel -
-                        (entrʲ_prev_level + turb_entrʲ_prev_level) *
-                        u³ʲ_data_prev_halflevel
-                    )
-                )
-
-            @. u³ʲ_datau³ʲ_data -=
-                (
-                    1 /
-                    (local_geometry_halflevel.J * local_geometry_halflevel.J)
-                ) * (
-                    local_geometry_prev_level.J *
-                    local_geometry_prev_level.J *
-                    2 *
-                    nh_pressure³ʲ_data_prev_halflevel
-                )
+            @. u³ʲ_datau³ʲ_data = compute_u³ʲ_u³ʲ(
+                u³ʲ_data_prev_halflevel,
+                local_geometry_prev_halflevel.J,
+                local_geometry_halflevel.J,
+                local_geometry_prev_level.J,
+                ∇Φ³_data_prev_level,
+                ρʲ_prev_level,
+                ρ_prev_level,
+                entrʲ_prev_level,
+                turb_entrʲ_prev_level,
+                u³⁰_data_prev_halflevel,
+                nh_pressure³ʲ_data_prev_halflevel,
+            )
 
             # get u³ʲ to calculate divergence term for detrainment,
             # u³ʲ will be clipped later after we get area fraction