Skip to content

Commit

Permalink
Transform FD broadcast objs to use MArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 31, 2024
1 parent 03320cf commit 7e9b0be
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 2 deletions.
12 changes: 12 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,18 @@ end
# Wrap `array` in a `VF` layout that carries the same `S` and `Nv` type parameters
# as `data`. Only the type of `data` is used; its contents are not read.
function rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T}
    return VF{S, Nv}(array)
end

"""
    rebuild_with_MArray(data::VF)

Return a copy of `data` whose backing array is a newly created `MArray` of size
`(nlevels(data), ncomponents(data))` with the element type of `parent(data)`.
The values are copied level by level (`copy[v] = data[v]`).
"""
@inline function rebuild_with_MArray(data::VF)
    n_levels = nlevels(data)
    n_comps = ncomponents(data)
    T = eltype(parent(data))
    LocalStorage = MArray{Tuple{n_levels, n_comps}, T, 2, n_levels * n_comps}
    column_copy = rebuild(data, LocalStorage(undef))
    # The storage was created with `undef`, so every level is written before returning.
    @inbounds for level in 1:n_levels
        column_copy[level] = data[level]
    end
    return column_copy
end

function replace_basetype(data::VF{S, Nv}, ::Type{T}) where {S, Nv, T}
array = parent(data)
S′ = replace_basetype(eltype(array), T, S)
Expand Down
130 changes: 130 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3364,6 +3364,8 @@ function Base.copyto!(
(Ni, Nj, _, _, Nh) = size(local_geometry)
context = ClimaComms.context(axes(field_out))
device = ClimaComms.device(context)
ᶜspace = Spaces.space(space, Grids.CellCenter())

if (device isa ClimaComms.CPUMultiThreaded) && Nh > 1
return _threaded_copyto!(field_out, bc, Ni, Nj, Nh)
end
Expand All @@ -3386,6 +3388,70 @@ function window_bounds(space, bc)
return (li, lw, rw, ri)
end

# Apply transform_to_local_mem() to every broadcast argument, peeling the tuple one
# element at a time so that the recursion is statically reducible by the optimizer.
# See Base.Broadcast.preprocess_args for the same pattern.
@inline transform_to_local_mem_args(args::Tuple{}, hidx, lg_data) = ()
@inline transform_to_local_mem_args(args::Tuple{Any}, hidx, lg_data) =
    (transform_to_local_mem(first(args), hidx, lg_data),)
@inline function transform_to_local_mem_args(args::Tuple, hidx, lg_data)
    head = transform_to_local_mem(first(args), hidx, lg_data)
    rest = transform_to_local_mem_args(Base.tail(args), hidx, lg_data)
    return (head, rest...)
end

# Rebuild a column stencil broadcast with the same operator (`bc.op`) and axes
# (`bc.axes`), after passing each of its arguments through
# transform_to_local_mem_args().
@inline function transform_to_local_mem(
    bc::StencilBroadcasted{ColumnStencilStyle},
    hidx,
    lg_data,
)
    local_args = transform_to_local_mem_args(bc.args, hidx, lg_data)
    return StencilBroadcasted{ColumnStencilStyle}(bc.op, local_args, bc.axes)
end

# Rebuild a Base broadcast object with the same function (`bc.f`) and axes
# (`bc.axes`), after passing each of its arguments through
# transform_to_local_mem_args().
# NOTE(review): the broadcast style of `bc` is not forwarded explicitly — confirm
# that the three-argument constructor yields the intended style.
@inline function transform_to_local_mem(bc::Base.Broadcast.Broadcasted, hidx, lg_data)
    return Base.Broadcast.Broadcasted(
        bc.f,
        transform_to_local_mem_args(bc.args, hidx, lg_data),
        bc.axes,
    )
end
import StaticArrays: MArray
"""
    transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data)

Return the column to use in place of `data`:

- If `eltype(data) <: Geometry.LocalGeometry`, return the entry of
  `lg_data = (ᶠlg, ᶜlg)` whose number of levels equals that of `data`; throw an
  error if neither matches.
- If `parent(data)` is already an `MArray`, return `data` unchanged.
- Otherwise return `DataLayouts.rebuild_with_MArray(data)`.

`hidx` is not used by this method.
"""
@inline function transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data)
    if eltype(data) <: Geometry.LocalGeometry
        # Reuse the local-geometry columns supplied in `lg_data` rather than
        # copying `data` again; the level count selects face vs. center.
        (ᶠlg, ᶜlg) = lg_data
        nlev = DataLayouts.nlevels(data)
        if nlev == DataLayouts.nlevels(ᶠlg)
            return ᶠlg
        elseif nlev == DataLayouts.nlevels(ᶜlg)
            return ᶜlg
        else
            # Static message (no interpolation) keeps the error path cheap.
            error(
                "transform_to_local_mem: LocalGeometry column matches neither " *
                "the face nor the center number of levels",
            )
        end
    elseif parent(data) isa MArray
        return data
    else
        return DataLayouts.rebuild_with_MArray(data)
    end
end
"""
    transform_to_local_mem(f::Fields.Field, hidx, lg_data)

Return a `Fields.Field` on the same axes as `f` whose values are
`transform_to_local_mem(Fields.field_values(f), hidx, lg_data)`.
"""
@inline function transform_to_local_mem(f::Fields.Field, hidx, lg_data)
    # `lg_data` is only forwarded; it is unpacked by the method that needs it.
    local_values = transform_to_local_mem(Fields.field_values(f), hidx, lg_data)
    return Fields.Field(local_values, axes(f))
end
# Tuples are transformed element-wise. The recursion peels off the first element
# and recurses on the tail, ending at the one-element and empty-tuple methods.
@inline transform_to_local_mem(x::Tuple{}, hidx, lg_data) = ()
@inline transform_to_local_mem(x::Tuple{Any}, hidx, lg_data) =
    (transform_to_local_mem(x[1], hidx, lg_data),)
@inline function transform_to_local_mem(x::Tuple, hidx, lg_data)
    head = transform_to_local_mem(x[1], hidx, lg_data)
    rest = transform_to_local_mem(Base.tail(x), hidx, lg_data)
    return (head, rest...)
end

# Fallback: any argument without a more specific method is returned unchanged.
@inline function transform_to_local_mem(x, hidx, lg_data)
    return x
end
# A VIJFH layout reaching this function is reported as an error.
@inline function transform_to_local_mem(x::DataLayouts.VIJFH, hidx, lg_data)
    return error("Data $x was not columnized.")
end

# Apply the stencil broadcast `bc` to the single column `hidx = (i, j, h)` and
# write the result into the matching column of `field_out`.
#
# Steps:
#   1. Take column (i, j, h) of `bc` and of the face/center local geometry data.
#   2. Copy both local-geometry columns into MArray-backed columns and rewrite the
#      column broadcast with transform_to_local_mem().
#   3. Evaluate the left boundary window `li:(lw - 1)`, the interior `lw:rw`, and
#      the right boundary window `(rw + 1):ri`. The boundary windows are skipped
#      when the vertical topology is periodic.
#
# Returns `field_out`.
Base.@propagate_inbounds function apply_stencil!(
    space,
    field_out,
    bc,
    hidx,
    (li, lw, rw, ri) = window_bounds(space, bc),
)
    (i, j, h) = hidx
    bc_col = Spaces.column(bc, i, j, h)
    ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space)
    ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space)
    ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i, j, h)
    ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i, j, h)
    lg_data = (
        DataLayouts.rebuild_with_MArray(ᶠlg_col),
        DataLayouts.rebuild_with_MArray(ᶜlg_col),
    )

    # Assign at function scope. A variable first assigned inside `try ... catch`
    # is local to that block and would be undefined in the loops below.
    bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data)
    field_out_col = Fields.column(field_out, i, j, h)

    has_boundaries = !Topologies.isperiodic(Spaces.vertical_topology(space))
    if has_boundaries
        # left window
        lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
        @inbounds for idx in li:(lw - 1)
            setidx!(
                space,
                field_out_col,
                idx,
                hidx,
                getidx(space, bc_localmem, lbw, idx, hidx),
            )
        end
    end
    # interior
    @inbounds for idx in lw:rw
        setidx!(
            space,
            field_out_col,
            idx,
            hidx,
            getidx(space, bc_localmem, Interior(), idx, hidx),
        )
    end
    if has_boundaries
        # right window
        rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}()
        @inbounds for idx in (rw + 1):ri
            setidx!(
                space,
                field_out_col,
                idx,
                hidx,
                getidx(space, bc_localmem, rbw, idx, hidx),
            )
        end
    end
    return field_out
end

Base.@propagate_inbounds function apply_stencil!(
space::Spaces.FiniteDifferenceSpace,
field_out,
bc,
hidx,
(li, lw, rw, ri) = window_bounds(space, bc),
)
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# left window
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
Expand Down
4 changes: 2 additions & 2 deletions test/Operators/finitedifference/opt_examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ function set_ᶠuₕ³!(ᶜx, ᶠx)
@. ᶠx.ᶠuₕ³ = ᶠwinterp(ᶜx.ρ * ᶜJ, CT3(ᶜx.uₕ))
return nothing
end
@testset "Inference/allocations when broadcasting types" begin
# @testset "Inference/allocations when broadcasting types" begin
FT = Float64
cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; zelem = 25, helem = 10)
fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace)
Expand All @@ -583,4 +583,4 @@ end
@benchmark set_ᶠuₕ³!($ ᶜx, $ᶠx)
end
show(stdout, MIME("text/plain"), trial)
end
# end

0 comments on commit 7e9b0be

Please sign in to comment.