Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 31, 2024
1 parent 3b779dd commit 6ca3a39
Show file tree
Hide file tree
Showing 11 changed files with 312 additions and 171 deletions.
2 changes: 2 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -921,12 +921,14 @@ rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T} =
Nv = nlevels(data)
Nf = ncomponents(data)
FT = eltype(parent(data))
# @show Nf
localmem = MArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(undef)
rdata = rebuild(data, localmem)
@inbounds for v in 1:Nv
rdata[v] = data[v]
end
rdata
# return rebuild(data, SArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(parent(data)))
end

function replace_basetype(data::VF{S, Nv}, ::Type{T}) where {S, Nv, T}
Expand Down
1 change: 1 addition & 0 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import ..DataLayouts:
FusedMultiBroadcast,
@fused_direct,
isascalar,
DataColumnStyle,
check_fused_broadcast_axes
import ..Domains
import ..Topologies
Expand Down
4 changes: 3 additions & 1 deletion src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ struct FieldStyle{DS <: DataStyle} <: AbstractFieldStyle end
FieldStyle(::DS) where {DS <: DataStyle} = FieldStyle{DS}()
FieldStyle(x::Base.Broadcast.Unknown) = x

DataColumnStyle(::Type{FieldStyle{DS}}) where {DS} = FieldStyle{DataColumnStyle(DS)}

Base.Broadcast.BroadcastStyle(::Type{Field{V, S}}) where {V, S} =
FieldStyle(DataStyle(V))

Expand Down Expand Up @@ -113,7 +115,7 @@ Base.@propagate_inbounds function column(
) where {Style <: AbstractFieldStyle}
_args = column_args(bc.args, i, j, h)
_axes = column(axes(bc), i, j, h)
Base.Broadcast.Broadcasted{Style}(bc.f, _args, _axes)
Base.Broadcast.Broadcasted{DataColumnStyle(Style)}(bc.f, _args, _axes)
end

# Return underlying DataLayout object, DataStyle of broadcasted
Expand Down
23 changes: 13 additions & 10 deletions src/Grids/column.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,25 +29,28 @@ end
A view into a column of a `ExtrudedFiniteDifferenceGrid`. This can be used as an
"""
struct ColumnGrid{
G <: AbstractExtrudedFiniteDifferenceGrid,
VG <: FiniteDifferenceGrid,
GG <: Geometry.AbstractGlobalGeometry,
C <: ColumnIndex,
} <: AbstractFiniteDifferenceGrid
full_grid::G
vertical_grid::VG
global_geometry::GG
colidx::C
end

local_geometry_type(::Type{ColumnGrid{G, C}}) where {G, C} =
local_geometry_type(G)
local_geometry_type(::Type{ColumnGrid{VG, C}}) where {VG, C} =
local_geometry_type(VG)

column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex) =
ColumnGrid(grid, colidx)
function column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex)
ColumnGrid(grid.vertical_grid, grid.global_geometry, colidx)
end

topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid)
vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid)
topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid)
vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid)

local_geometry_data(colgrid::ColumnGrid, staggering::Staggering) = column(
local_geometry_data(colgrid.full_grid, staggering::Staggering),
local_geometry_data(colgrid.vertical_grid, staggering::Staggering),
colgrid.colidx.ij...,
colgrid.colidx.h,
)
global_geometry(colgrid::ColumnGrid) = global_geometry(colgrid.full_grid)
global_geometry(colgrid::ColumnGrid) = colgrid.global_geometry
2 changes: 1 addition & 1 deletion src/Grids/extruded.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ abstract type AbstractExtrudedFiniteDifferenceGrid <: AbstractGrid end
Construct an `ExtrudedFiniteDifferenceGrid` from the horizontal and vertical spaces.
"""
mutable struct ExtrudedFiniteDifferenceGrid{
struct ExtrudedFiniteDifferenceGrid{
H <: AbstractGrid,
V <: FiniteDifferenceGrid,
A <: HypsographyAdaption,
Expand Down
2 changes: 1 addition & 1 deletion src/Grids/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ This is an object which contains all the necessary geometric information.
To avoid unnecessary duplication, we memoize the construction of the grid.
"""
mutable struct FiniteDifferenceGrid{
struct FiniteDifferenceGrid{
T <: Topologies.AbstractIntervalTopology,
GG,
CLG,
Expand Down
4 changes: 2 additions & 2 deletions src/Grids/spectralelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ abstract type AbstractSpectralElementGrid <: AbstractGrid end
A one-dimensional space: within each element the space is represented as a polynomial.
"""
mutable struct SpectralElementGrid1D{
struct SpectralElementGrid1D{
T,
Q,
GG <: Geometry.AbstractGlobalGeometry,
Expand Down Expand Up @@ -104,7 +104,7 @@ end
A two-dimensional space: within each element the space is represented as a polynomial.
"""
mutable struct SpectralElementGrid2D{
struct SpectralElementGrid2D{
T,
Q,
GG <: Geometry.AbstractGlobalGeometry,
Expand Down
87 changes: 87 additions & 0 deletions src/Operators/check_for_non_column.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
@inline check_for_non_column_args(args::Tuple, inds...) = (
check_for_non_column(args[1], inds...),
check_for_non_column_args(Base.tail(args), inds...)...,
)
@inline check_for_non_column_args(args::Tuple{Any}, inds...) =
(check_for_non_column(args[1], inds...),)
@inline check_for_non_column_args(args::Tuple{}, inds...) = ()

@inline function check_for_non_column(
bc::StencilBroadcasted{Style},
inds...
) where {Style}
StencilBroadcasted{Style}(
bc.op,
check_for_non_column_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column(
bc::Base.Broadcast.Broadcasted{Style},
inds...
) where {Style}
Base.Broadcast.Broadcasted{Style}(
bc.f,
check_for_non_column_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column(f::Fields.Field, inds...)
check_for_non_column(Fields.field_values(f), inds...)
return Fields.Field(Fields.field_values(f), axes(f))
end
@inline check_for_non_column(x::Tuple, inds...) =
(check_for_non_column(first(x), inds...),
check_for_non_column(Base.tail(x), inds...)...)
@inline check_for_non_column(x::Tuple{Any}, inds...) =
(check_for_non_column(first(x), inds...),)
@inline check_for_non_column(x::Tuple{}, inds...) = ()

@inline check_for_non_column(x, inds...) = x
@inline check_for_non_column(x::DataLayouts.VIJFH, inds...) = error("Found non-column data $x.")
@inline check_for_non_column(x::DataLayouts.VIFH, inds...) = error("Found non-column data $x.")


# $$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$

@inline check_for_non_column_style_args(args::Tuple, inds...) = (
check_for_non_column_style(args[1], inds...),
check_for_non_column_style_args(Base.tail(args), inds...)...,
)
@inline check_for_non_column_style_args(args::Tuple{Any}, inds...) =
(check_for_non_column_style(args[1], inds...),)
@inline check_for_non_column_style_args(args::Tuple{}, inds...) = ()

@inline function check_for_non_column_style(
bc::StencilBroadcasted{Style},
inds...
) where {Style}
check_for_non_column_style(Style)
StencilBroadcasted{Style}(
bc.op,
check_for_non_column_style_args(bc.args, inds...),
bc.axes
)
end
@inline function check_for_non_column_style(
bc::Base.Broadcast.Broadcasted{Style},
inds...
) where {Style}
check_for_non_column_style(Style)
Base.Broadcast.Broadcasted{Style}(
bc.f,
check_for_non_column_style_args(bc.args, inds...),
bc.axes
)
end
@inline check_for_non_column_style(::Fields.FieldStyle{DS}, inds...) where {DS} = check_for_non_column_style(DS)
@inline check_for_non_column_style(::Type{Fields.FieldStyle{DS}}, inds...) where {DS} = check_for_non_column_style(DS)
@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIJFHStyle} = error("Found non-column style")
@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIFHStyle} = error("Found non-column style")
@inline check_for_non_column_style(x::Tuple) =
(check_for_non_column_style(first(x), inds...),
check_for_non_column_style(Base.tail(x), inds...)...)
@inline check_for_non_column_style(x::Tuple{Any}, inds...) =
(check_for_non_column_style(first(x), inds...),)
@inline check_for_non_column_style(x::Tuple{}, inds...) = ()
@inline check_for_non_column_style(x, inds...) = x
111 changes: 76 additions & 35 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,34 @@ StencilBroadcasted{Style}(
) where {Style, Op, Args, Axes} =
StencilBroadcasted{Style, Op, Args, Axes}(op, args, axes)

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{StencilStyle},
inds...,
)
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{ColumnStencilStyle()}(bc.op, _args, _axes)
end

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{ColumnStencilStyle},
inds...,
)
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{ColumnStencilStyle}(bc.op, _args, _axes)
end

Base.@propagate_inbounds function column(
bc::StencilBroadcasted{Style},
inds...,
) where {Style}
error("Uncaught case")
_args = column_args(bc.args, inds...)
_axes = column(axes(bc), inds...)
StencilBroadcasted{Style}(bc.op, _args, _axes)
end

Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} =
StencilBroadcasted{Style}(
Adapt.adapt(to, sbc.op),
Expand Down Expand Up @@ -3386,6 +3414,14 @@ function window_bounds(space, bc)
return (li, lw, rw, ri)
end

to_MArray_style(::Type{T}) where {T} = T
to_MArray_style(::Type{Fields.FieldStyle{DS}}) where {DS} = Fields.FieldStyle{to_MArray_style(DS)}
to_MArray_style(::Val{Nv}, ::Type{T}) where {Nv, T<:AbstractArray} = MArray{Tuple{Nv}, eltype(T)}

to_MArray_style(::Type{DataLayouts.VFStyle{Nv, A}}) where {Nv, A} = DataLayouts.VFStyle{Nv, to_MArray_style(Val(Nv), A)}
to_MArray_style(::Type{DataLayouts.VIFHStyle{Nv, Ni, A}}) where {Nv, Ni, A} = DataLayouts.VIFHStyle{Nv, Ni, to_MArray_style(Val(Nv), A)}
to_MArray_style(::Type{DataLayouts.VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} = DataLayouts.VIJFHStyle{Nv, Nij, to_MArray_style(Val(Nv), A)}

# Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
@inline transform_to_local_mem_args(args::Tuple, hidx, lg_data) = (
Expand All @@ -3408,29 +3444,31 @@ end
end

@inline function transform_to_local_mem(
bc::Base.Broadcast.Broadcasted,
bc::Base.Broadcast.Broadcasted{Style},
hidx, lg_data
)
) where {Style}
args = transform_to_local_mem_args(bc.args, hidx, lg_data)
Base.Broadcast.Broadcasted(
Style_mem = RecursiveApply.rmaptype(to_MArray_style, Style)
Base.Broadcast.Broadcasted{Style_mem}(
bc.f,
args,
bc.axes
)
end
import StaticArrays: MArray
@inline function transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data)
if eltype(data) <: Geometry.LocalGeometry # we al
(ᶠlg, ᶜlg) = lg_data
if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg)
return ᶠlg
elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg)
return ᶜlg
else
error("oops")
end
elseif parent(data) isa MArray
if parent(data) isa MArray
return data
elseif eltype(data) <: Geometry.LocalGeometry # we al
return data
# (ᶠlg, ᶜlg) = lg_data
# if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg)
# return ᶠlg
# elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg)
# return ᶜlg
# else
# error("oops")
# end
else
return DataLayouts.rebuild_with_MArray(data)
end
Expand All @@ -3449,7 +3487,8 @@ end
@inline transform_to_local_mem(x::Tuple{}, hidx, lg_data) = ()

@inline transform_to_local_mem(x, hidx, lg_data) = x
@inline transform_to_local_mem(x::DataLayouts.VIJFH, hidx, lg_data) = error("Data $x was not columnized.")

include("check_for_non_column.jl")

Base.@propagate_inbounds function apply_stencil!(
space,
Expand All @@ -3459,44 +3498,46 @@ Base.@propagate_inbounds function apply_stencil!(
(li, lw, rw, ri) = window_bounds(space, bc),
)

(i, j, h) = hidx
bc_col = Spaces.column(bc, i,j,h)
ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space)
ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space)
ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h)
ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h)
ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col)
ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col)
lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem)

try
bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data)
catch
@show bc_col
if true
(i, j, h) = hidx
bc_col = Spaces.column(bc, i,j,h)
ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space)
ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space)
ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h)
ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h)
# ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col)
# ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col)
ᶠlg_col_localmem = nothing
ᶜlg_col_localmem = nothing
lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem)
bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data)
bc_used = bc_localmem
field_out_used = Fields.column(field_out, i,j,h)
else
bc_used = bc
field_out_used = field_out
end
field_out_col = Fields.column(field_out, i,j,h)
if !Topologies.isperiodic(Spaces.vertical_topology(space))
# left window
lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}()
@inbounds for idx in li:(lw - 1)
setidx!(
space,
field_out_col,
field_out_used,
idx,
hidx,
getidx(space, bc_localmem, lbw, idx, hidx),
getidx(space, bc_used, lbw, idx, hidx),
)
end
end
# interior
@inbounds for idx in lw:rw
setidx!(
space,
field_out_col,
field_out_used,
idx,
hidx,
getidx(space, bc_localmem, Interior(), idx, hidx),
getidx(space, bc_used, Interior(), idx, hidx),
)
end
if !Topologies.isperiodic(Spaces.vertical_topology(space))
Expand All @@ -3505,10 +3546,10 @@ Base.@propagate_inbounds function apply_stencil!(
@inbounds for idx in (rw + 1):ri
setidx!(
space,
field_out_col,
field_out_used,
idx,
hidx,
getidx(space, bc_localmem, rbw, idx, hidx),
getidx(space, bc_used, rbw, idx, hidx),
)
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/Topologies/topology2d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Internally, we can refer to elements in several different ways:
- `ridx`: "receive index": an index into the receive buffer of a ghost element.
- `recv_elem_gidx[ridx] == gidx`
"""
mutable struct Topology2D{
struct Topology2D{
C <: ClimaComms.AbstractCommsContext,
M <: Meshes.AbstractMesh{2},
EO,
Expand Down
Loading

0 comments on commit 6ca3a39

Please sign in to comment.