Skip to content

Commit

Permalink
Manually fuse some broadcast expressions in diagnostic edmf
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Oct 8, 2024
1 parent a905ac9 commit b32468f
Showing 1 changed file with 58 additions and 54 deletions.
112 changes: 58 additions & 54 deletions src/cache/diagnostic_edmf_precomputed_quantities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,51 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_bottom_bc!(
return nothing
end

"""
    compute_u³ʲ_u³ʲ(
        u³ʲ_prev_halflevel, J_prev_halflevel, J_halflevel, J_prev_level,
        ∇Φ³_data_prev_level, ρʲ_prev_level, ρ_prev_level,
        entrʲ_prev_level, turb_entrʲ_prev_level,
        u³⁰_data_prev_halflevel, nh_pressure³ʲ_data_prev_halflevel,
    )

Return the pointwise value

    (1 / J_halflevel^2) * (
        J_prev_halflevel^2 * u³ʲ_prev_halflevel^2
        - 2 * J_prev_level^2 * ∇Φ³ * (ρʲ - ρ) / ρʲ
        + 2 * J_prev_level^2 * (entrʲ + turb_entrʲ) * (u³⁰ - u³ʲ_prev_halflevel)
        - 2 * J_prev_level^2 * nh_pressure³ʲ
    )

evaluated as four separately scaled terms that are then combined left to right.

The function only uses scalar arithmetic on its arguments, so it can be applied
elementwise through a single broadcast call.
"""
function compute_u³ʲ_u³ʲ(
    u³ʲ_prev_halflevel,
    J_prev_halflevel,
    J_halflevel,
    J_prev_level,
    ∇Φ³_data_prev_level,
    ρʲ_prev_level,
    ρ_prev_level,
    entrʲ_prev_level,
    turb_entrʲ_prev_level,
    u³⁰_data_prev_halflevel,
    nh_pressure³ʲ_data_prev_halflevel,
)
    # Factors shared by several of the terms below.
    inv_Jsq_halflevel = 1 / (J_halflevel^2)
    twice_Jsq_prev_level = J_prev_level^2 * 2
    total_entrʲ = entrʲ_prev_level + turb_entrʲ_prev_level

    # Squared-velocity contribution carried over from the previous half level.
    squared_velocity_term =
        inv_Jsq_halflevel *
        (J_prev_halflevel^2 * u³ʲ_prev_halflevel * u³ʲ_prev_halflevel)

    # Contribution proportional to the relative density difference.
    density_term =
        inv_Jsq_halflevel * (
            twice_Jsq_prev_level *
            (∇Φ³_data_prev_level * (ρʲ_prev_level - ρ_prev_level) / ρʲ_prev_level)
        )

    # Contribution from the summed entrainment rates acting on u³⁰ and u³ʲ.
    entrainment_term =
        inv_Jsq_halflevel * (
            twice_Jsq_prev_level * (
                total_entrʲ * u³⁰_data_prev_halflevel -
                total_entrʲ * u³ʲ_prev_halflevel
            )
        )

    # Contribution from nh_pressure³ʲ.
    pressure_term =
        inv_Jsq_halflevel *
        (twice_Jsq_prev_level * nh_pressure³ʲ_data_prev_halflevel)

    # Combine strictly left to right so rounding matches a running accumulation.
    return ((squared_velocity_term - density_term) + entrainment_term) - pressure_term
end


NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
Y,
p,
Expand Down Expand Up @@ -490,60 +535,19 @@ NVTX.@annotate function set_diagnostic_edmf_precomputed_quantities_do_integral!(
end

u³ʲ_datau³ʲ_data = p.scratch.temp_data_level
# Using constant exponents in broadcasts allocates, so we use
# local_geometry_halflevel.J * local_geometry_halflevel.J instead.
# See ClimaCore.jl issue #1126.
@. u³ʲ_datau³ʲ_data =
(
1 /
(local_geometry_halflevel.J * local_geometry_halflevel.J)
) * (
local_geometry_prev_halflevel.J *
local_geometry_prev_halflevel.J *
u³ʲ_data_prev_halflevel *
u³ʲ_data_prev_halflevel
)

@. u³ʲ_datau³ʲ_data -=
(
1 /
(local_geometry_halflevel.J * local_geometry_halflevel.J)
) * (
local_geometry_prev_level.J *
local_geometry_prev_level.J *
2 *
(
∇Φ³_data_prev_level * (ρʲ_prev_level - ρ_prev_level) /
ρʲ_prev_level
)
)

@. u³ʲ_datau³ʲ_data +=
(
1 /
(local_geometry_halflevel.J * local_geometry_halflevel.J)
) * (
local_geometry_prev_level.J *
local_geometry_prev_level.J *
2 *
(
(entrʲ_prev_level + turb_entrʲ_prev_level) *
u³⁰_data_prev_halflevel -
(entrʲ_prev_level + turb_entrʲ_prev_level) *
u³ʲ_data_prev_halflevel
)
)

@. u³ʲ_datau³ʲ_data -=
(
1 /
(local_geometry_halflevel.J * local_geometry_halflevel.J)
) * (
local_geometry_prev_level.J *
local_geometry_prev_level.J *
2 *
nh_pressure³ʲ_data_prev_halflevel
)
@. u³ʲ_datau³ʲ_data = compute_u³ʲ_u³ʲ(
u³ʲ_prev_halflevel,
local_geometry_prev_halflevel.J,
local_geometry_halflevel.J,
local_geometry_prev_level.J,
∇Φ³_data_prev_level,
ρʲ_prev_level,
ρ_prev_level,
entrʲ_prev_level,
turb_entrʲ_prev_level,
u³⁰_data_prev_halflevel,
nh_pressure³ʲ_data_prev_halflevel,
)

# get u³ʲ to calculate divergence term for detrainment,
# u³ʲ will be clipped later after we get area fraction
Expand Down

0 comments on commit b32468f

Please sign in to comment.