From 6ca3a39d46ec9219a700997eb7091b8126c02154 Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 31 May 2024 14:00:24 -0400 Subject: [PATCH] wip --- src/DataLayouts/DataLayouts.jl | 2 + src/Fields/Fields.jl | 1 + src/Fields/broadcast.jl | 4 +- src/Grids/column.jl | 23 +- src/Grids/extruded.jl | 2 +- src/Grids/finitedifference.jl | 2 +- src/Grids/spectralelement.jl | 4 +- src/Operators/check_for_non_column.jl | 87 +++++++ src/Operators/finitedifference.jl | 111 +++++--- src/Topologies/topology2d.jl | 2 +- .../finitedifference/opt_examples.jl | 245 +++++++++--------- 11 files changed, 312 insertions(+), 171 deletions(-) create mode 100644 src/Operators/check_for_non_column.jl diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index db48b2ad93..146670c518 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -921,12 +921,14 @@ rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T} = Nv = nlevels(data) Nf = ncomponents(data) FT = eltype(parent(data)) + # @show Nf localmem = MArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(undef) rdata = rebuild(data, localmem) @inbounds for v in 1:Nv rdata[v] = data[v] end rdata + # return rebuild(data, SArray{Tuple{Nv, Nf}, FT, 2, Nv * Nf}(parent(data))) end function replace_basetype(data::VF{S, Nv}, ::Type{T}) where {S, Nv, T} diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl index 992b1f5aa0..4aa3f89806 100644 --- a/src/Fields/Fields.jl +++ b/src/Fields/Fields.jl @@ -10,6 +10,7 @@ import ..DataLayouts: FusedMultiBroadcast, @fused_direct, isascalar, + DataColumnStyle, check_fused_broadcast_axes import ..Domains import ..Topologies diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index b5c3f0cfcd..208558e34c 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -15,6 +15,8 @@ struct FieldStyle{DS <: DataStyle} <: AbstractFieldStyle end FieldStyle(::DS) where {DS <: DataStyle} = FieldStyle{DS}() 
FieldStyle(x::Base.Broadcast.Unknown) = x +DataColumnStyle(::Type{FieldStyle{DS}}) where {DS} = FieldStyle{DataColumnStyle(DS)} + Base.Broadcast.BroadcastStyle(::Type{Field{V, S}}) where {V, S} = FieldStyle(DataStyle(V)) @@ -113,7 +115,7 @@ Base.@propagate_inbounds function column( ) where {Style <: AbstractFieldStyle} _args = column_args(bc.args, i, j, h) _axes = column(axes(bc), i, j, h) - Base.Broadcast.Broadcasted{Style}(bc.f, _args, _axes) + Base.Broadcast.Broadcasted{DataColumnStyle(Style)}(bc.f, _args, _axes) end # Return underlying DataLayout object, DataStyle of broadcasted diff --git a/src/Grids/column.jl b/src/Grids/column.jl index d31825f6cb..8b3c8c81f3 100644 --- a/src/Grids/column.jl +++ b/src/Grids/column.jl @@ -29,25 +29,28 @@ end A view into a column of a `ExtrudedFiniteDifferenceGrid`. This can be used as an """ struct ColumnGrid{ - G <: AbstractExtrudedFiniteDifferenceGrid, + VG <: FiniteDifferenceGrid, + GG <: Geometry.AbstractGlobalGeometry, C <: ColumnIndex, } <: AbstractFiniteDifferenceGrid - full_grid::G + vertical_grid::VG + global_geometry::GG colidx::C end -local_geometry_type(::Type{ColumnGrid{G, C}}) where {G, C} = - local_geometry_type(G) +local_geometry_type(::Type{ColumnGrid{VG, GG, C}}) where {VG, GG, C} = + local_geometry_type(VG) -column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex) = - ColumnGrid(grid, colidx) +function column(grid::AbstractExtrudedFiniteDifferenceGrid, colidx::ColumnIndex) + ColumnGrid(grid.vertical_grid, grid.global_geometry, colidx) +end -topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid) -vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.full_grid) +topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid) +vertical_topology(colgrid::ColumnGrid) = vertical_topology(colgrid.vertical_grid) local_geometry_data(colgrid::ColumnGrid, staggering::Staggering) = column( - local_geometry_data(colgrid.full_grid, staggering::Staggering), +
local_geometry_data(colgrid.vertical_grid, staggering::Staggering), colgrid.colidx.ij..., colgrid.colidx.h, ) -global_geometry(colgrid::ColumnGrid) = global_geometry(colgrid.full_grid) +global_geometry(colgrid::ColumnGrid) = colgrid.global_geometry diff --git a/src/Grids/extruded.jl b/src/Grids/extruded.jl index 15b68c9423..d4f97825cd 100644 --- a/src/Grids/extruded.jl +++ b/src/Grids/extruded.jl @@ -22,7 +22,7 @@ abstract type AbstractExtrudedFiniteDifferenceGrid <: AbstractGrid end Construct an `ExtrudedFiniteDifferenceGrid` from the horizontal and vertical spaces. """ -mutable struct ExtrudedFiniteDifferenceGrid{ +struct ExtrudedFiniteDifferenceGrid{ H <: AbstractGrid, V <: FiniteDifferenceGrid, A <: HypsographyAdaption, diff --git a/src/Grids/finitedifference.jl b/src/Grids/finitedifference.jl index e1a991343a..a5fd04fe3b 100644 --- a/src/Grids/finitedifference.jl +++ b/src/Grids/finitedifference.jl @@ -29,7 +29,7 @@ This is an object which contains all the necessary geometric information. To avoid unnecessary duplication, we memoize the construction of the grid. """ -mutable struct FiniteDifferenceGrid{ +struct FiniteDifferenceGrid{ T <: Topologies.AbstractIntervalTopology, GG, CLG, diff --git a/src/Grids/spectralelement.jl b/src/Grids/spectralelement.jl index 1cf8e3178a..a0aeb326b8 100644 --- a/src/Grids/spectralelement.jl +++ b/src/Grids/spectralelement.jl @@ -7,7 +7,7 @@ abstract type AbstractSpectralElementGrid <: AbstractGrid end A one-dimensional space: within each element the space is represented as a polynomial. """ -mutable struct SpectralElementGrid1D{ +struct SpectralElementGrid1D{ T, Q, GG <: Geometry.AbstractGlobalGeometry, @@ -104,7 +104,7 @@ end A two-dimensional space: within each element the space is represented as a polynomial. 
""" -mutable struct SpectralElementGrid2D{ +struct SpectralElementGrid2D{ T, Q, GG <: Geometry.AbstractGlobalGeometry, diff --git a/src/Operators/check_for_non_column.jl b/src/Operators/check_for_non_column.jl new file mode 100644 index 0000000000..630ded2007 --- /dev/null +++ b/src/Operators/check_for_non_column.jl @@ -0,0 +1,87 @@ +@inline check_for_non_column_args(args::Tuple, inds...) = ( + check_for_non_column(args[1], inds...), + check_for_non_column_args(Base.tail(args), inds...)..., +) +@inline check_for_non_column_args(args::Tuple{Any}, inds...) = + (check_for_non_column(args[1], inds...),) +@inline check_for_non_column_args(args::Tuple{}, inds...) = () + +@inline function check_for_non_column( + bc::StencilBroadcasted{Style}, + inds... +) where {Style} + StencilBroadcasted{Style}( + bc.op, + check_for_non_column_args(bc.args, inds...), + bc.axes + ) +end +@inline function check_for_non_column( + bc::Base.Broadcast.Broadcasted{Style}, + inds... +) where {Style} + Base.Broadcast.Broadcasted{Style}( + bc.f, + check_for_non_column_args(bc.args, inds...), + bc.axes + ) +end +@inline function check_for_non_column(f::Fields.Field, inds...) + check_for_non_column(Fields.field_values(f), inds...) + return Fields.Field(Fields.field_values(f), axes(f)) +end +@inline check_for_non_column(x::Tuple, inds...) = + (check_for_non_column(first(x), inds...), + check_for_non_column(Base.tail(x), inds...)...) +@inline check_for_non_column(x::Tuple{Any}, inds...) = + (check_for_non_column(first(x), inds...),) +@inline check_for_non_column(x::Tuple{}, inds...) = () + +@inline check_for_non_column(x, inds...) = x +@inline check_for_non_column(x::DataLayouts.VIJFH, inds...) = error("Found non-column data $x.") +@inline check_for_non_column(x::DataLayouts.VIFH, inds...) = error("Found non-column data $x.") + + +# $$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$ + +@inline check_for_non_column_style_args(args::Tuple, inds...) 
= ( + check_for_non_column_style(args[1], inds...), + check_for_non_column_style_args(Base.tail(args), inds...)..., +) +@inline check_for_non_column_style_args(args::Tuple{Any}, inds...) = + (check_for_non_column_style(args[1], inds...),) +@inline check_for_non_column_style_args(args::Tuple{}, inds...) = () + +@inline function check_for_non_column_style( + bc::StencilBroadcasted{Style}, + inds... +) where {Style} + check_for_non_column_style(Style) + StencilBroadcasted{Style}( + bc.op, + check_for_non_column_style_args(bc.args, inds...), + bc.axes + ) +end +@inline function check_for_non_column_style( + bc::Base.Broadcast.Broadcasted{Style}, + inds... +) where {Style} + check_for_non_column_style(Style) + Base.Broadcast.Broadcasted{Style}( + bc.f, + check_for_non_column_style_args(bc.args, inds...), + bc.axes + ) +end +@inline check_for_non_column_style(::Fields.FieldStyle{DS}, inds...) where {DS} = check_for_non_column_style(DS) +@inline check_for_non_column_style(::Type{Fields.FieldStyle{DS}}, inds...) where {DS} = check_for_non_column_style(DS) +@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIJFHStyle} = error("Found non-column style") +@inline check_for_non_column_style(::Type{DS}, inds...) where {DS <: DataLayouts.VIFHStyle} = error("Found non-column style") +@inline check_for_non_column_style(x::Tuple, inds...) = + (check_for_non_column_style(first(x), inds...), + check_for_non_column_style(Base.tail(x), inds...)...) +@inline check_for_non_column_style(x::Tuple{Any}, inds...) = + (check_for_non_column_style(first(x), inds...),) +@inline check_for_non_column_style(x::Tuple{}, inds...) = () +@inline check_for_non_column_style(x, inds...)
= x diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index aa8e83e0d2..921145c71a 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -233,6 +233,34 @@ StencilBroadcasted{Style}( ) where {Style, Op, Args, Axes} = StencilBroadcasted{Style, Op, Args, Axes}(op, args, axes) +Base.@propagate_inbounds function column( + bc::StencilBroadcasted{StencilStyle}, + inds..., +) + _args = column_args(bc.args, inds...) + _axes = column(axes(bc), inds...) + StencilBroadcasted{ColumnStencilStyle}(bc.op, _args, _axes) +end + +Base.@propagate_inbounds function column( + bc::StencilBroadcasted{ColumnStencilStyle}, + inds..., +) + _args = column_args(bc.args, inds...) + _axes = column(axes(bc), inds...) + StencilBroadcasted{ColumnStencilStyle}(bc.op, _args, _axes) +end + +Base.@propagate_inbounds function column( + bc::StencilBroadcasted{Style}, + inds..., +) where {Style} + error("Uncaught case") + _args = column_args(bc.args, inds...) + _axes = column(axes(bc), inds...)
+ StencilBroadcasted{Style}(bc.op, _args, _axes) +end + Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} = StencilBroadcasted{Style}( Adapt.adapt(to, sbc.op), @@ -3386,6 +3414,14 @@ function window_bounds(space, bc) return (li, lw, rw, ri) end +to_MArray_style(::Type{T}) where {T} = T +to_MArray_style(::Type{Fields.FieldStyle{DS}}) where {DS} = Fields.FieldStyle{to_MArray_style(DS)} +to_MArray_style(::Val{Nv}, ::Type{T}) where {Nv, T<:AbstractArray} = MArray{Tuple{Nv}, eltype(T)} + +to_MArray_style(::Type{DataLayouts.VFStyle{Nv, A}}) where {Nv, A} = DataLayouts.VFStyle{Nv, to_MArray_style(Val(Nv), A)} +to_MArray_style(::Type{DataLayouts.VIFHStyle{Nv, Ni, A}}) where {Nv, Ni, A} = DataLayouts.VIFHStyle{Nv, Ni, to_MArray_style(Val(Nv), A)} +to_MArray_style(::Type{DataLayouts.VIJFHStyle{Nv, Nij, A}}) where {Nv, Nij, A} = DataLayouts.VIJFHStyle{Nv, Nij, to_MArray_style(Val(Nv), A)} + # Recursively call transform_bc_args() on broadcast arguments in a way that is statically reducible by the optimizer # see Base.Broadcast.preprocess_args @inline transform_to_local_mem_args(args::Tuple, hidx, lg_data) = ( @@ -3408,11 +3444,12 @@ end end @inline function transform_to_local_mem( - bc::Base.Broadcast.Broadcasted, + bc::Base.Broadcast.Broadcasted{Style}, hidx, lg_data -) +) where {Style} args = transform_to_local_mem_args(bc.args, hidx, lg_data) - Base.Broadcast.Broadcasted( + Style_mem = RecursiveApply.rmaptype(to_MArray_style, Style) + Base.Broadcast.Broadcasted{Style_mem}( bc.f, args, bc.axes @@ -3420,17 +3457,18 @@ end end import StaticArrays: MArray @inline function transform_to_local_mem(data::DataLayouts.DataColumn, hidx, lg_data) - if eltype(data) <: Geometry.LocalGeometry # we al - (ᶠlg, ᶜlg) = lg_data - if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg) - return ᶠlg - elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg) - return ᶜlg - else - error("oops") - end - elseif parent(data) isa MArray + if parent(data) isa MArray return 
data + elseif eltype(data) <: Geometry.LocalGeometry # we al + return data + # (ᶠlg, ᶜlg) = lg_data + # if DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶠlg) + # return ᶠlg + # elseif DataLayouts.nlevels(data) == DataLayouts.nlevels(ᶜlg) + # return ᶜlg + # else + # error("oops") + # end else return DataLayouts.rebuild_with_MArray(data) end @@ -3449,7 +3487,8 @@ end @inline transform_to_local_mem(x::Tuple{}, hidx, lg_data) = () @inline transform_to_local_mem(x, hidx, lg_data) = x -@inline transform_to_local_mem(x::DataLayouts.VIJFH, hidx, lg_data) = error("Data $x was not columnized.") + +include("check_for_non_column.jl") Base.@propagate_inbounds function apply_stencil!( space, @@ -3459,33 +3498,35 @@ Base.@propagate_inbounds function apply_stencil!( (li, lw, rw, ri) = window_bounds(space, bc), ) - (i, j, h) = hidx - bc_col = Spaces.column(bc, i,j,h) - ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space) - ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space) - ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h) - ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h) - ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col) - ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col) - lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem) - - try - bc_localmem = transform_to_local_mem(bc_col, hidx, lg_data) - catch - @show bc_col + if true + (i, j, h) = hidx + bc_col = Spaces.column(bc, i,j,h) + ᶠspace = Spaces.FaceExtrudedFiniteDifferenceSpace(space) + ᶜspace = Spaces.CenterExtrudedFiniteDifferenceSpace(space) + ᶠlg_col = Spaces.column(Spaces.local_geometry_data(ᶠspace), i,j,h) + ᶜlg_col = Spaces.column(Spaces.local_geometry_data(ᶜspace), i,j,h) + # ᶠlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶠlg_col) + # ᶜlg_col_localmem = DataLayouts.rebuild_with_MArray(ᶜlg_col) + ᶠlg_col_localmem = nothing + ᶜlg_col_localmem = nothing + lg_data = (ᶠlg_col_localmem, ᶜlg_col_localmem) bc_localmem = transform_to_local_mem(bc_col, hidx, 
lg_data) + bc_used = bc_localmem + field_out_used = Fields.column(field_out, i,j,h) + else + bc_used = bc + field_out_used = field_out end - field_out_col = Fields.column(field_out, i,j,h) if !Topologies.isperiodic(Spaces.vertical_topology(space)) # left window lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() @inbounds for idx in li:(lw - 1) setidx!( space, - field_out_col, + field_out_used, idx, hidx, - getidx(space, bc_localmem, lbw, idx, hidx), + getidx(space, bc_used, lbw, idx, hidx), ) end end @@ -3493,10 +3534,10 @@ Base.@propagate_inbounds function apply_stencil!( @inbounds for idx in lw:rw setidx!( space, - field_out_col, + field_out_used, idx, hidx, - getidx(space, bc_localmem, Interior(), idx, hidx), + getidx(space, bc_used, Interior(), idx, hidx), ) end if !Topologies.isperiodic(Spaces.vertical_topology(space)) @@ -3505,10 +3546,10 @@ Base.@propagate_inbounds function apply_stencil!( @inbounds for idx in (rw + 1):ri setidx!( space, - field_out_col, + field_out_used, idx, hidx, - getidx(space, bc_localmem, rbw, idx, hidx), + getidx(space, bc_used, rbw, idx, hidx), ) end end diff --git a/src/Topologies/topology2d.jl b/src/Topologies/topology2d.jl index fede3253e5..24b3d62203 100644 --- a/src/Topologies/topology2d.jl +++ b/src/Topologies/topology2d.jl @@ -23,7 +23,7 @@ Internally, we can refer to elements in several different ways: - `ridx`: "receive index": an index into the receive buffer of a ghost element. 
- `recv_elem_gidx[ridx] == gidx` """ -mutable struct Topology2D{ +struct Topology2D{ C <: ClimaComms.AbstractCommsContext, M <: Meshes.AbstractMesh{2}, EO, diff --git a/test/Operators/finitedifference/opt_examples.jl b/test/Operators/finitedifference/opt_examples.jl index 86ea28970b..b595cd543d 100644 --- a/test/Operators/finitedifference/opt_examples.jl +++ b/test/Operators/finitedifference/opt_examples.jl @@ -1,4 +1,9 @@ -import ClimaCore, ClimaComms, CUDA +#= +julia --project=test +using Revise; include(joinpath("test", "Operators", "finitedifference", "opt_examples.jl")) +=# +import ClimaCore, ClimaComms +ClimaComms.@import_required_backends using BenchmarkTools @isdefined(TU) || include( joinpath(pkgdir(ClimaCore), "test", "TestUtilities", "TestUtilities.jl"), @@ -439,118 +444,118 @@ function alloc_test_nested_expressions_13( end end -@testset "FD operator allocation tests" begin - FT = Float64 - n_elems = 1000 - domain = Domains.IntervalDomain( - Geometry.ZPoint{FT}(0.0), - Geometry.ZPoint{FT}(pi); - boundary_names = (:bottom, :top), - ) - mesh = Meshes.IntervalMesh(domain; nelems = n_elems) - cs = Spaces.CenterFiniteDifferenceSpace(mesh) - fs = Spaces.FaceFiniteDifferenceSpace(cs) - zc = getproperty(Fields.coordinate_field(cs), :z) - zf = getproperty(Fields.coordinate_field(fs), :z) - cfield_vars() = - (; cx = FT(0), cy = FT(0), cz = FT(0), cϕ = FT(0), cψ = FT(0)) - ffield_vars() = - (; fx = FT(0), fy = FT(0), fz = FT(0), fϕ = FT(0), fψ = FT(0)) - cntfield_vars() = (; nt = ntuple(i -> cfield_vars(), n_tuples)) - fntfield_vars() = (; nt = ntuple(i -> ffield_vars(), n_tuples)) - cfield = fill(cfield_vars(), cs) - ffield = fill(ffield_vars(), fs) - ntcfield = fill(cntfield_vars(), cs) - ntffield = fill(fntfield_vars(), fs) - wvec_glob = Geometry.WVector - - alloc_test_f2c_interp(cfield, ffield) - - alloc_test_c2f_interp( - cfield, - ffield, - Operators.InterpolateC2F(; - bottom = Operators.SetValue(0), - top = Operators.SetValue(0), - ), - ) - 
alloc_test_c2f_interp( - cfield, - ffield, - Operators.InterpolateC2F(; - bottom = Operators.SetGradient(wvec_glob(0)), - top = Operators.SetGradient(wvec_glob(0)), - ), - ) - alloc_test_c2f_interp( - cfield, - ffield, - Operators.InterpolateC2F(; - bottom = Operators.Extrapolate(), - top = Operators.Extrapolate(), - ), - ) - alloc_test_c2f_interp( - cfield, - ffield, - Operators.LeftBiasedC2F(; bottom = Operators.SetValue(0)), - ) - alloc_test_c2f_interp( - cfield, - ffield, - Operators.RightBiasedC2F(; top = Operators.SetValue(0)), - ) - - alloc_test_derivative( - cfield, - ffield, - Operators.DivergenceF2C(), - Operators.DivergenceC2F(; - bottom = Operators.SetValue(wvec_glob(0)), - top = Operators.SetValue(wvec_glob(0)), - ), - ) - alloc_test_derivative( - cfield, - ffield, - Operators.DivergenceF2C(; - bottom = Operators.SetValue(wvec_glob(0)), - top = Operators.SetValue(wvec_glob(0)), - ), - Operators.DivergenceC2F(; - bottom = Operators.SetValue(wvec_glob(0)), - top = Operators.SetValue(wvec_glob(0)), - ), - ) - alloc_test_derivative( - cfield, - ffield, - Operators.DivergenceF2C(; - bottom = Operators.Extrapolate(), - top = Operators.Extrapolate(), - ), - Operators.DivergenceC2F(; - bottom = Operators.SetDivergence(0), - top = Operators.SetDivergence(0), - ), - ) - - alloc_test_redefined_operators(cfield, ffield) - alloc_test_operators_in_loops(cfield, ffield) - alloc_test_nested_expressions_1(cfield, ffield) - alloc_test_nested_expressions_2(cfield, ffield) - alloc_test_nested_expressions_3(cfield, ffield) - alloc_test_nested_expressions_4(cfield, ffield) - alloc_test_nested_expressions_5(cfield, ffield) - alloc_test_nested_expressions_6(cfield, ffield) - alloc_test_nested_expressions_7(cfield, ffield) - alloc_test_nested_expressions_8(cfield, ffield) - alloc_test_nested_expressions_9(cfield, ffield) - alloc_test_nested_expressions_10(cfield, ffield) - alloc_test_nested_expressions_11(cfield, ffield) - alloc_test_nested_expressions_12(cfield, ffield, 
ntcfield, ntffield) - alloc_test_nested_expressions_13(cfield, ffield, ntcfield, ntffield, FT) -end +# @testset "FD operator allocation tests" begin +# FT = Float64 +# n_elems = 1000 +# domain = Domains.IntervalDomain( +# Geometry.ZPoint{FT}(0.0), +# Geometry.ZPoint{FT}(pi); +# boundary_names = (:bottom, :top), +# ) +# mesh = Meshes.IntervalMesh(domain; nelems = n_elems) +# cs = Spaces.CenterFiniteDifferenceSpace(mesh) +# fs = Spaces.FaceFiniteDifferenceSpace(cs) +# zc = getproperty(Fields.coordinate_field(cs), :z) +# zf = getproperty(Fields.coordinate_field(fs), :z) +# cfield_vars() = +# (; cx = FT(0), cy = FT(0), cz = FT(0), cϕ = FT(0), cψ = FT(0)) +# ffield_vars() = +# (; fx = FT(0), fy = FT(0), fz = FT(0), fϕ = FT(0), fψ = FT(0)) +# cntfield_vars() = (; nt = ntuple(i -> cfield_vars(), n_tuples)) +# fntfield_vars() = (; nt = ntuple(i -> ffield_vars(), n_tuples)) +# cfield = fill(cfield_vars(), cs) +# ffield = fill(ffield_vars(), fs) +# ntcfield = fill(cntfield_vars(), cs) +# ntffield = fill(fntfield_vars(), fs) +# wvec_glob = Geometry.WVector + +# alloc_test_f2c_interp(cfield, ffield) + +# alloc_test_c2f_interp( +# cfield, +# ffield, +# Operators.InterpolateC2F(; +# bottom = Operators.SetValue(0), +# top = Operators.SetValue(0), +# ), +# ) +# alloc_test_c2f_interp( +# cfield, +# ffield, +# Operators.InterpolateC2F(; +# bottom = Operators.SetGradient(wvec_glob(0)), +# top = Operators.SetGradient(wvec_glob(0)), +# ), +# ) +# alloc_test_c2f_interp( +# cfield, +# ffield, +# Operators.InterpolateC2F(; +# bottom = Operators.Extrapolate(), +# top = Operators.Extrapolate(), +# ), +# ) +# alloc_test_c2f_interp( +# cfield, +# ffield, +# Operators.LeftBiasedC2F(; bottom = Operators.SetValue(0)), +# ) +# alloc_test_c2f_interp( +# cfield, +# ffield, +# Operators.RightBiasedC2F(; top = Operators.SetValue(0)), +# ) + +# alloc_test_derivative( +# cfield, +# ffield, +# Operators.DivergenceF2C(), +# Operators.DivergenceC2F(; +# bottom = Operators.SetValue(wvec_glob(0)), +# top = 
Operators.SetValue(wvec_glob(0)), +# ), +# ) +# alloc_test_derivative( +# cfield, +# ffield, +# Operators.DivergenceF2C(; +# bottom = Operators.SetValue(wvec_glob(0)), +# top = Operators.SetValue(wvec_glob(0)), +# ), +# Operators.DivergenceC2F(; +# bottom = Operators.SetValue(wvec_glob(0)), +# top = Operators.SetValue(wvec_glob(0)), +# ), +# ) +# alloc_test_derivative( +# cfield, +# ffield, +# Operators.DivergenceF2C(; +# bottom = Operators.Extrapolate(), +# top = Operators.Extrapolate(), +# ), +# Operators.DivergenceC2F(; +# bottom = Operators.SetDivergence(0), +# top = Operators.SetDivergence(0), +# ), +# ) + +# alloc_test_redefined_operators(cfield, ffield) +# alloc_test_operators_in_loops(cfield, ffield) +# alloc_test_nested_expressions_1(cfield, ffield) +# alloc_test_nested_expressions_2(cfield, ffield) +# alloc_test_nested_expressions_3(cfield, ffield) +# alloc_test_nested_expressions_4(cfield, ffield) +# alloc_test_nested_expressions_5(cfield, ffield) +# alloc_test_nested_expressions_6(cfield, ffield) +# alloc_test_nested_expressions_7(cfield, ffield) +# alloc_test_nested_expressions_8(cfield, ffield) +# alloc_test_nested_expressions_9(cfield, ffield) +# alloc_test_nested_expressions_10(cfield, ffield) +# alloc_test_nested_expressions_11(cfield, ffield) +# alloc_test_nested_expressions_12(cfield, ffield, ntcfield, ntffield) +# alloc_test_nested_expressions_13(cfield, ffield, ntcfield, ntffield, FT) +# end # https://github.com/CliMA/ClimaCore.jl/issues/1602 @@ -567,7 +572,7 @@ function set_ᶠuₕ³!(ᶜx, ᶠx) end # @testset "Inference/allocations when broadcasting types" begin FT = Float64 - cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; zelem = 25, helem = 10) + cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; zelem = 63, helem = 20) fspace = Spaces.FaceExtrudedFiniteDifferenceSpace(cspace) device = ClimaComms.device(cspace) @info "device = $device" @@ -577,10 +582,10 @@ end p_allocated = @allocated set_ᶠuₕ³!(ᶜx, ᶠx) @show p_allocated - trial = if device 
isa ClimaComms.CUDADevice - @benchmark CUDA.@sync set_ᶠuₕ³!($ ᶜx, $ᶠx) - else - @benchmark set_ᶠuₕ³!($ ᶜx, $ᶠx) - end - show(stdout, MIME("text/plain"), trial) + # trial = if device isa ClimaComms.CUDADevice + # @benchmark CUDA.@sync set_ᶠuₕ³!($ ᶜx, $ᶠx) + # else + # @benchmark set_ᶠuₕ³!($ ᶜx, $ᶠx) + # end + # show(stdout, MIME("text/plain"), trial) # end