Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform FD broadcast objs to use MArrays #1763

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,20 @@ end
rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T} =
VF{S, Nv}(array)

@inline function rebuild_with_MArray(data::VF)
Nv = nlevels(data)
Nf = ncomponents(data)
FT = eltype(parent(data))
# @show Nf
localmem = MArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(undef)
rdata = rebuild(data, localmem)
@inbounds for v in 1:Nv
rdata[v] = data[v]
end
rdata
# return rebuild(data, SArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(parent(data)))
end

function replace_basetype(data::VF{S, Nv}, ::Type{T}) where {S, Nv, T}
array = parent(data)
S′ = replace_basetype(eltype(array), T, S)
Expand Down
1 change: 1 addition & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import ..DataLayouts:
FusedMultiBroadcast,
@fused_direct,
isascalar,
DataColumnStyle,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
Expand Down
4 changes: 3 additions & 1 deletion src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ struct FieldStyle{DS <: DataStyle} <: AbstractFieldStyle end
FieldStyle(::DS) where {DS <: DataStyle} = FieldStyle{DS}()
FieldStyle(x::Base.Broadcast.Unknown) = x

DataColumnStyle(::Type{FieldStyle{DS}}) where {DS} = FieldStyle{DataColumnStyle(DS)}

Base.Broadcast.BroadcastStyle(::Type{Field{V, S}}) where {V, S} =
FieldStyle(DataStyle(V))

Expand Down Expand Up @@ -113,7 +115,7 @@ Base.@propagate_inbounds function column(
) where {Style <: AbstractFieldStyle}
_args = column_args(bc.args, i, j, h)
_axes = column(axes(bc), i, j, h)
Base.Broadcast.Broadcasted{Style}(bc.f, _args, _axes)
Base.Broadcast.Broadcasted{DataColumnStyle(Style)}(bc.f, _args, _axes)
end

# Return underlying DataLayout object, DataStyle of broadcasted
Expand Down
23 changes: 13 additions & 10 deletions src/Grids/column.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,28 @@ end
A view into a column of a `ExtrudedFiniteDifferenceGrid`. This can be used as an
"""
struct ColumnGrid{
G <: AbstractExtrudedFiniteDifferenceGrid,
VG <: FiniteDifferenceGrid,
GG <: Geometry.AbstractGlobalGeometry,
C <: ColumnIndex,
} <: AbstractFiniteDifferenceGrid
full_grid::G
vertical_grid::VG
global_geometry::GG
colidx::C
end

local_geometry_type(::Type{ColumnGrid{G, C}}) where {G, C} =
local_geometry_type(G)
local_geometry_type(::Type{ColumnGrid{VG, C}}) where {VG, C} =
local_geometry_type(VG)

column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex) =
ColumnGrid(grid, colidx)
function column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex)
ColumnGrid(grid.vertical_grid, grid.global_geometry, colidx)
end

topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid)
vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid)
topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid)
vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid)

local_geometry_data(colgrid::ColumnGrid, staggering::Staggering) = column(
local_geometry_data(colgrid.full_grid, staggering::Staggering),
local_geometry_data(colgrid.vertical_grid, staggering::Staggering),
colgrid.colidx.ij...,
colgrid.colidx.h,
)
global_geometry(colgrid::ColumnGrid) = global_geometry(colgrid.full_grid)
global_geometry(colgrid::ColumnGrid) = colgrid.global_geometry
2 changes: 1 addition & 1 deletion src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ abstract type AbstractExtrudedFiniteDifferenceGrid <: AbstractGrid end
Construct an `ExtrudedFiniteDifferenceGrid` from the horizontal and vertical spaces.
"""
mutable struct ExtrudedFiniteDifferenceGrid{
struct ExtrudedFiniteDifferenceGrid{
H <: AbstractGrid,
V <: FiniteDifferenceGrid,
A <: HypsographyAdaption,
Expand Down
2 changes: 1 addition & 1 deletion src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ This is an object which contains all the necessary geometric information.
To avoid unnecessary duplication, we memoize the construction of the grid.
"""
mutable struct FiniteDifferenceGrid{
struct FiniteDifferenceGrid{
T <: Topologies.AbstractIntervalTopology,
GG,
CLG,
Expand Down
4 changes: 2 additions & 2 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ abstract type AbstractSpectralElementGrid <: AbstractGrid end
A one-dimensional space: within each element the space is represented as a polynomial.
"""
mutable struct SpectralElementGrid1D{
struct SpectralElementGrid1D{
T,
Q,
GG <: Geometry.AbstractGlobalGeometry,
Expand Down Expand Up @@ -104,7 +104,7 @@ end
A two-dimensional space: within each element the space is represented as a polynomial.
"""
mutable struct SpectralElementGrid2D{
struct SpectralElementGrid2D{
T,
Q,
GG <: Geometry.AbstractGlobalGeometry,
Expand Down
87 changes: 87 additions & 0 deletions src/Operators/check_for_non_column.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
@inline check_for_non_column_args(args::Tuple, inds...) = (
check_for_non_column(args[1], inds...),
check_for_non_column_args(Base.tail(args), inds...)...,
)
@inline check_for_non_column_args(args::Tuple{Any}, inds...) =
(check_for_non_column(args[1], inds...),)
@inline check_for_non_column_args(args::Tuple{}, inds...) = ()

@inline function check_for_non_column(
bc::StencilBroadcasted{Style},
inds...
) where {Style}
StencilBroadcasted{Style}(
bc.op,
check_for_non_column_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column(
bc::Base.Broadcast.Broadcasted{Style},
inds...
) where {Style}
Base.Broadcast.Broadcasted{Style}(
bc.f,
check_for_non_column_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column(f::Fields.Field, inds...)
check_for_non_column(Fields.field_values(f), inds...)
return Fields.Field(Fields.field_values(f), axes(f))
end
@inline check_for_non_column(x::Tuple, inds...) =
(check_for_non_column(first(x), inds...),
check_for_non_column(Base.tail(x), inds...)...)
@inline check_for_non_column(x::Tuple{Any}, inds...) =
(check_for_non_column(first(x), inds...),)
@inline check_for_non_column(x::Tuple{}, inds...) = ()

@inline check_for_non_column(x, inds...) = x
@inline check_for_non_column(x::DataLayouts.VIJFH, inds...) = error("Found non-column data $x.")
@inline check_for_non_column(x::DataLayouts.VIFH, inds...) = error("Found non-column data $x.")


# $$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$

@inline check_for_non_column_style_args(args::Tuple, inds...) = (
check_for_non_column_style(args[1], inds...),
check_for_non_column_style_args(Base.tail(args), inds...)...,
)
@inline check_for_non_column_style_args(args::Tuple{Any}, inds...) =
(check_for_non_column_style(args[1], inds...),)
@inline check_for_non_column_style_args(args::Tuple{}, inds...) = ()

@inline function check_for_non_column_style(
bc::StencilBroadcasted{Style},
inds...
) where {Style}
check_for_non_column_style(Style)
StencilBroadcasted{Style}(
bc.op,
check_for_non_column_style_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column_style(
bc::Base.Broadcast.Broadcasted{Style},
inds...
) where {Style}
check_for_non_column_style(Style)
Base.Broadcast.Broadcasted{Style}(
bc.f,
check_for_non_column_style_args(bc.args, inds...),
bc.axes
)
end
@inline check_for_non_column_style(::Fields.FieldStyle{DS}, inds...) where {DS} = check_for_non_column_style(DS)
@inline check_for_non_column_style(::Type{Fields.FieldStyle{DS}}, inds...) where {DS} = check_for_non_column_style(DS)
@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIJFHStyle} = error("Found non-column style")
@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIFHStyle} = error("Found non-column style")
@inline check_for_non_column_style(x::Tuple) =
(check_for_non_column_style(first(x), inds...),
check_for_non_column_style(Base.tail(x), inds...)...)
@inline check_for_non_column_style(x::Tuple{Any}, inds...) =
(check_for_non_column_style(first(x), inds...),)
@inline check_for_non_column_style(x::Tuple{}, inds...) = ()
@inline check_for_non_column_style(x, inds...) = x
169 changes: 169 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,34 @@ StencilBroadcasted{Style}(
) where {Style, Op, Args, Axes} =
StencilBroadcasted{Style, Op, Args, Axes}(op, args, axes)

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{StencilStyle},
inds...,
)
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{ColumnStencilStyle()}(bc.op, _args, _axes)
end

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{ColumnStencilStyle},
inds...,
)
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{ColumnStencilStyle}(bc.op, _args, _axes)
end

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{Style},
inds...,
) where {Style}
error("Uncaught case")
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{Style}(bc.op, _args, _axes)
end

Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} =
StencilBroadcasted{Style}(
Adapt.adapt(to, sbc.op),
Expand Down Expand Up @@ -3386,6 +3414,81 @@ function window_bounds(space, bc)
return (li, lw, rw, ri)
end

to_MArray_style(::Type{T}) where {T} = T
to_MArray_style(::Type{Fields.FieldStyle{DS}}) where {DS} = Fields.FieldStyle{to_MArray_style(DS)}
to_MArray_style(::Val{Nv}, ::Type{T}) where {Nv, T<:AbstractArray} = MArray{Tuple{Nv}, eltype(T)}

to_MArray_style(::Type{DataLayouts.VFStyle{Nv, A}}) where {Nv, A} = DataLayouts.VFStyle{Nv, to_MArray_style(Val(Nv), A)}
to_MArray_style(::Type{DataLayouts.VIFHStyle{Nv, Ni, A}}) where {Nv, Ni, A} = DataLayouts.VIFHStyle{Nv, Ni, to_MArray_style(Val(Nv), A)}
to_MArray_style(::Type{DataLayouts.VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} = DataLayouts.VIJFHStyle{Nv, Nij, to_MArray_style(Val(Nv), A)}

# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
@inline transform_to_local_mem_args(args::Tuple, hidx, lg_data) = (
transform_to_local_mem(args[1], hidx, lg_data),
transform_to_local_mem_args(Base.tail(args), hidx, lg_data)...,
)
@inline transform_to_local_mem_args(args::Tuple{Any}, hidx, lg_data) =
(transform_to_local_mem(args[1], hidx, lg_data),)
@inline transform_to_local_mem_args(args::Tuple{}, hidx, lg_data) = ()

@inline function transform_to_local_mem(
bc::StencilBroadcasted{ColumnStencilStyle},
hidx, lg_data
)
StencilBroadcasted{ColumnStencilStyle}(
bc.op,
transform_to_local_mem_args(bc.args, hidx, lg_data),
bc.axes
)
end

@inline function transform_to_local_mem(
bc::Base.Broadcast.Broadcasted{Style},
hidx, lg_data
) where {Style}
args = transform_to_local_mem_args(bc.args, hidx, lg_data)
Style_mem = RecursiveApply.rmaptype(to_MArray_style, Style)
Base.Broadcast.Broadcasted{Style_mem}(
bc.f,
args,
bc.axes
)
end
import StaticArrays: MArray
@inline function transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data)
if parent(data) isa MArray
return data
elseif eltype(data) <: Geometry.LocalGeometry # we al
return data
# (ᶠlg, ᶜlg) = lg_data
# if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg)
# return ᶠlg
# elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg)
# return ᶜlg
# else
# error("oops")
# end
else
return DataLayouts.rebuild_with_MArray(data)
end
end
@inline function transform_to_local_mem(f::Fields.Field, hidx, lg_data)
(ᶠlg, ᶜlg) = lg_data
fdata = Fields.field_values(f)
datacol_lm = transform_to_local_mem(fdata, hidx, lg_data)
return Fields.Field(datacol_lm, axes(f))
end
@inline transform_to_local_mem(x::Tuple, hidx, lg_data) =
(transform_to_local_mem(first(x), hidx, lg_data),
transform_to_local_mem(Base.tail(x), hidx, lg_data)...)
@inline transform_to_local_mem(x::Tuple{Any}, hidx, lg_data) =
(transform_to_local_mem(first(x), hidx, lg_data),)
@inline transform_to_local_mem(x::Tuple{}, hidx, lg_data) = ()

@inline transform_to_local_mem(x, hidx, lg_data) = x

include("check_for_non_column.jl")

Base.@propagate_inbounds function apply_stencil!(
space,
Expand All @@ -3394,6 +3497,72 @@ Base.@propagate_inbounds function apply_stencil!(
hidx,
(li, lw, rw, ri) = window_bounds(space, bc),
)

if true
(i, j, h) = hidx
bc_col = Spaces.column(bc, i,j,h)
ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space)
ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space)
ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h)
ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h)
# ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col)
# ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col)
ᶠlg_col_localmem = nothing
ᶜlg_col_localmem = nothing
lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem)
bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data)
bc_used = bc_localmem
field_out_used = Fields.column(field_out, i,j,h)
else
bc_used = bc
field_out_used = field_out
end
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# left window
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
@inbounds for idx in li:(lw - 1)
setidx!(
space,
field_out_used,
idx,
hidx,
getidx(space, bc_used, lbw, idx, hidx),
)
end
end
# interior
@inbounds for idx in lw:rw
setidx!(
space,
field_out_used,
idx,
hidx,
getidx(space, bc_used, Interior(), idx, hidx),
)
end
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# right window
rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}()
@inbounds for idx in (rw + 1):ri
setidx!(
space,
field_out_used,
idx,
hidx,
getidx(space, bc_used, rbw, idx, hidx),
)
end
end
return field_out
end

Base.@propagate_inbounds function apply_stencil!(
space::Spaces.FiniteDifferenceSpace,
field_out,
bc,
hidx,
(li, lw, rw, ri) = window_bounds(space, bc),
)
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# left window
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
Expand Down
Loading
Loading