From 9d5f538ed6f9c329af95cb5ba13fc1784f8453ca Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Wed, 31 Jul 2024 10:09:12 -0400
Subject: [PATCH] Add linear index support for pointwise kernels
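
Pointwise kernels previously addressed every element of every broadcast
argument through a CartesianIndex, paying multi-dimensional index
arithmetic per element. When all datalayouts in a broadcast expression
are of the same kind, the arguments share a single memory layout, so one
linear index can address the destination and every argument. This patch
adds `has_uniform_datalayouts` to detect that case, a
`NonExtrudedBroadcasted` wrapper that indexes broadcast arguments
linearly, and linear `getindex`/`setindex!` for datalayouts, and wires
them into the CPU and CUDA pointwise `copyto!` and `fill!` paths.

Schematically, the new CPU fast path is (a sketch, not the literal code):

    if has_uniform_datalayouts(bc)
        bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
        @inbounds @simd for I in 1:get_N(UniversalSize(dest))
            dest[I] = convert(S, bc′[I]) # one linear index for everything
        end
    else
        # fall back to the CartesianIndex-based copyto!
    end

The CUDA path makes the same choice between `knl_copyto_linear!` and the
CartesianIndex-based `knl_copyto_cart!`.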
---
 ext/cuda/data_layouts.jl                    |  13 ++
 ext/cuda/data_layouts_copyto.jl             | 113 +++--------
 ext/cuda/data_layouts_fill.jl               |   4 +-
 src/DataLayouts/DataLayouts.jl              |  33 +++
 src/DataLayouts/broadcast.jl                |   1 +
 src/DataLayouts/copyto.jl                   |  18 +-
 src/DataLayouts/fill.jl                     |  61 +-----
 src/DataLayouts/has_uniform_datalayouts.jl  |  60 ++++++
 src/DataLayouts/non_extruded_broadcasted.jl | 158 +++++++++++++++
 src/DataLayouts/struct.jl                   | 147 ++++++++++++++
 src/DataLayouts/to_linear_index.jl          |  49 +++++
 test/DataLayouts/unit_copyto.jl             |   3 +-
 .../unit_has_uniform_datalayouts.jl         |  49 +++++
 test/DataLayouts/unit_linear_indexing.jl    | 191 ++++++++++++++++++
 test/runtests.jl                            |   1 +
 15 files changed, 750 insertions(+), 151 deletions(-)
 create mode 100644 src/DataLayouts/has_uniform_datalayouts.jl
 create mode 100644 src/DataLayouts/non_extruded_broadcasted.jl
 create mode 100644 src/DataLayouts/to_linear_index.jl
 create mode 100644 test/DataLayouts/unit_has_uniform_datalayouts.jl
 create mode 100644 test/DataLayouts/unit_linear_indexing.jl

diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl
index 20cc9d7178..7178fccb9f 100644
--- a/ext/cuda/data_layouts.jl
+++ b/ext/cuda/data_layouts.jl
@@ -53,3 +53,16 @@ function Adapt.adapt_structure(
         end,
     )
 end
+
+import Adapt
+import CUDA
+function Adapt.adapt_structure(
+    to::CUDA.KernelAdaptor,
+    bc::DataLayouts.NonExtrudedBroadcasted{Style},
+) where {Style}
+    DataLayouts.NonExtrudedBroadcasted{Style}(
+        adapt_f(to, bc.f),
+        Adapt.adapt(to, bc.args),
+        Adapt.adapt(to, bc.axes),
+    )
+end
diff --git a/ext/cuda/data_layouts_copyto.jl b/ext/cuda/data_layouts_copyto.jl
index 5439a61527..c85a737b06 100644
--- a/ext/cuda/data_layouts_copyto.jl
+++ b/ext/cuda/data_layouts_copyto.jl
@@ -1,90 +1,9 @@
+import ClimaCore.DataLayouts:
+    to_non_extruded_broadcasted, has_uniform_datalayouts
 DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()
-function knl_copyto!(dest, src)
-
-    i = CUDA.threadIdx().x
-    j = CUDA.threadIdx().y
-
-    h = CUDA.blockIdx().x
-    v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z
-
-    if v <= size(dest, 4)
-        I = CartesianIndex((i, j, 1, v, h))
-        @inbounds dest[I] = src[I]
-    end
-    return nothing
-end
-
-function Base.copyto!(
-    dest::IJFH{S, Nij, Nh},
-    bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
-    ::ToCUDA,
-) where {S, Nij, Nh}
-    if Nh > 0
-        auto_launch!(
-            knl_copyto!,
-            (dest, bc),
-            dest;
-            threads_s = (Nij, Nij),
-            blocks_s = (Nh, 1),
-        )
-    end
-    return dest
-end
-
-function Base.copyto!(
-    dest::VIJFH{S, Nv, Nij, Nh},
-    bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
-    ::ToCUDA,
-) where {S, Nv, Nij, Nh}
-    if Nv > 0 && Nh > 0
-        Nv_per_block = min(Nv, fld(256, Nij * Nij))
-        Nv_blocks = cld(Nv, Nv_per_block)
-        auto_launch!(
-            knl_copyto!,
-            (dest, bc),
-            dest;
-            threads_s = (Nij, Nij, Nv_per_block),
-            blocks_s = (Nh, Nv_blocks),
-        )
-    end
-    return dest
-end
-
-function Base.copyto!(
-    dest::VF{S, Nv},
-    bc::DataLayouts.BroadcastedUnionVF{S, Nv},
-    ::ToCUDA,
-) where {S, Nv}
-    if Nv > 0
-        auto_launch!(
-            knl_copyto!,
-            (dest, bc),
-            dest;
-            threads_s = (1, 1),
-            blocks_s = (1, Nv),
-        )
-    end
-    return dest
-end
-
-function Base.copyto!(
-    dest::DataF{S},
-    bc::DataLayouts.BroadcastedUnionDataF{S},
-    ::ToCUDA,
-) where {S}
-    auto_launch!(
-        knl_copyto!,
-        (dest, bc),
-        dest;
-        threads_s = (1, 1),
-        blocks_s = (1, 1),
-    )
-    return dest
-end
-
 import ClimaCore.DataLayouts: isascalar
-function knl_copyto_flat!(dest::AbstractData, bc, us)
+function knl_copyto_cart!(dest::AbstractData, bc, us)
     @inbounds begin
         tidx = thread_index()
         if tidx ≤ get_N(us)
@@ -96,11 +15,25 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
     return nothing
 end

+function knl_copyto_linear!(dest::AbstractData, bc, us)
+    @inbounds begin
+        tidx = thread_index()
+        if tidx ≤ get_N(us)
+            dest[tidx] = bc[tidx]
+        end
+    end
+    return nothing
+end
+
 function cuda_copyto!(dest::AbstractData, bc)
     (_, _, Nv, Nh) = DataLayouts.universal_size(dest)
+    (Nv > 0 && Nh > 0) || return dest
     us = DataLayouts.UniversalSize(dest)
-    if Nv > 0 && Nh > 0
-        auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
+    if has_uniform_datalayouts(bc)
+        bc′ = to_non_extruded_broadcasted(bc)
+        auto_launch!(knl_copyto_linear!, (dest, bc′, us), dest; auto = true)
+    else
+        auto_launch!(knl_copyto_cart!, (dest, bc, us), dest; auto = true)
     end
     return dest
 end
@@ -108,12 +41,12 @@ end
 # TODO: can we use CUDA's launch configuration for all data layouts?
 # Currently, it seems to cause a slight performance degradation.
 #! format: off
-# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
+Base.copyto!(dest::IJFH{S, Nij, Nh}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
 Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
 Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
 Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
 Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
-# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
-# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
-# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
+Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
+Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
+Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
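+
+# All pointwise copyto! methods now route through `cuda_copyto!`, which
+# launches `knl_copyto_linear!` when `has_uniform_datalayouts(bc)` holds,
+# and falls back to the CartesianIndex-based `knl_copyto_cart!` otherwise.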
 #! format: on
diff --git a/ext/cuda/data_layouts_fill.jl b/ext/cuda/data_layouts_fill.jl
index 9999c65a8a..5d5f3b9236 100644
--- a/ext/cuda/data_layouts_fill.jl
+++ b/ext/cuda/data_layouts_fill.jl
@@ -2,9 +2,7 @@ function knl_fill_flat!(dest::AbstractData, val, us)
     @inbounds begin
         tidx = thread_index()
         if tidx ≤ get_N(us)
-            n = size(dest)
-            I = kernel_indexes(tidx, n)
-            @inbounds dest[I] = val
+            @inbounds dest[tidx] = val
         end
     end
     return nothing
diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl
index 2d65ef6c9a..276dd3539f 100644
--- a/src/DataLayouts/DataLayouts.jl
+++ b/src/DataLayouts/DataLayouts.jl
@@ -1523,6 +1523,37 @@ get_Nij(::IF{S, Nij}) where {S, Nij} = Nij
 @inline field_dim(::VIJFH) = 4
 @inline field_dim(::VIFH) = 3

+# Returns the size of the backing array, with the field dimension collapsed
+# to 1.
+@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} =
+    (Nij, Nij, Nk, 1, Nv, Nh)
+@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
+@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
+@inline array_size(::DataF{S}) where {S} = (1,)
+@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
+@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
+@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
+@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
+    (Nv, Nij, Nij, 1, Nh)
+@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
+    (Nv, Ni, 1, Nh)
+
+# Returns the full size of the backing array, with `ncomponents(data)`
+# entries along the field dimension.
+@inline farray_size(
+    data::IJKFVH{S, Nij, Nk, Nv, Nh},
+) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
+@inline farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
+    (Nij, Nij, ncomponents(data), Nh)
+@inline farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} =
+    (Ni, ncomponents(data), Nh)
+@inline farray_size(data::DataF{S}) where {S} = (ncomponents(data),)
+@inline farray_size(data::IJF{S, Nij}) where {S, Nij} =
+    (Nij, Nij, ncomponents(data))
+@inline farray_size(data::IF{S, Ni}) where {S, Ni} = (Ni, ncomponents(data))
+@inline farray_size(data::VF{S, Nv}) where {S, Nv} = (Nv, ncomponents(data))
+@inline farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
+    (Nv, Nij, Nij, ncomponents(data), Nh)
+@inline farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
+    (Nv, Ni, ncomponents(data), Nh)
+
 #! format: off
 @inline to_data_specific(::IJFH, I::CartesianIndex) = CartesianIndex(I.I[1], I.I[2], 1, I.I[5])
 @inline to_data_specific(::IFH, I::CartesianIndex) = CartesianIndex(I.I[1], 1, I.I[5])
@@ -1600,10 +1631,12 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
 _device_dispatch(x::SArray) = ToCPU()
 _device_dispatch(x::MArray) = ToCPU()

+include("non_extruded_broadcasted.jl")
 include("copyto.jl")
 include("fused_copyto.jl")
 include("fill.jl")
 include("mapreduce.jl")
+include("has_uniform_datalayouts.jl")

 slab_index(i, j) = CartesianIndex(i, j, 1, 1, 1)
 slab_index(i) = CartesianIndex(i, 1, 1, 1, 1)
diff --git a/src/DataLayouts/broadcast.jl b/src/DataLayouts/broadcast.jl
index 9da4d81fc8..2ce83e4d3d 100644
--- a/src/DataLayouts/broadcast.jl
+++ b/src/DataLayouts/broadcast.jl
@@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
 #####

 #! format: off
+const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
 const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
 const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
 const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}
diff --git a/src/DataLayouts/copyto.jl b/src/DataLayouts/copyto.jl
index 4a94638edb..c887c11476 100644
--- a/src/DataLayouts/copyto.jl
+++ b/src/DataLayouts/copyto.jl
@@ -2,10 +2,22 @@
 #####
 ##### Dispatching and edge cases
 #####

-Base.copyto!(
-    dest::AbstractData,
+function Base.copyto!(
+    dest::AbstractData{S},
     bc::Union{AbstractData, Base.Broadcast.Broadcasted},
-) = Base.copyto!(dest, bc, device_dispatch(dest))
+) where {S}
+    dev = device_dispatch(dest)
+    if dev isa ToCPU && has_uniform_datalayouts(bc)
+        # Specialize on the linear-indexing case:
+        bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
+        @inbounds @simd for I in 1:get_N(UniversalSize(dest))
+            dest[I] = convert(S, bc′[I])
+        end
+    else
+        Base.copyto!(dest, bc, dev)
+    end
+    return dest
+end

 # Specialize on non-Broadcasted objects
 function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
diff --git a/src/DataLayouts/fill.jl b/src/DataLayouts/fill.jl
index c942b0c959..e1998c93aa 100644
--- a/src/DataLayouts/fill.jl
+++ b/src/DataLayouts/fill.jl
@@ -1,60 +1,13 @@
-function Base.fill!(data::IJFH, val, ::ToCPU)
-    (_, _, _, _, Nh) = size(data)
-    @inbounds for h in 1:Nh
-        fill!(slab(data, h), val)
+function Base.fill!(dest::AbstractData, val, ::ToCPU)
+    @inbounds @simd for I in 1:get_N(UniversalSize(dest))
+        dest[I] = val
     end
-    return data
+    return dest
 end

-function Base.fill!(data::IFH, val, ::ToCPU)
-    (_, _, _, _, Nh) = size(data)
-    @inbounds for h in 1:Nh
-        fill!(slab(data, h), val)
-    end
-    return data
-end
-
-function Base.fill!(data::DataF, val, ::ToCPU)
-    @inbounds data[] = val
-    return data
-end
-
-function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
-    @inbounds for j in 1:Nij, i in 1:Nij
-        data[CartesianIndex(i, j, 1, 1, 1)] = val
-    end
-    return data
-end
-
-function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
-    @inbounds for i in 1:Ni
-        data[CartesianIndex(i, 1, 1, 1, 1)] = val
-    end
-    return data
-end
-
-function Base.fill!(data::VF, val, ::ToCPU)
-    Nv = nlevels(data)
-    @inbounds for v in 1:Nv
-        data[CartesianIndex(1, 1, 1, v, 1)] = val
-    end
-    return data
-end
-
-function Base.fill!(data::VIJFH, val, ::ToCPU)
-    (Ni, Nj, _, Nv, Nh) = size(data)
-    @inbounds for h in 1:Nh, v in 1:Nv
-        fill!(slab(data, v, h), val)
-    end
-    return data
-end
-
-function Base.fill!(data::VIFH, val, ::ToCPU)
-    (Ni, _, _, Nv, Nh) = size(data)
-    @inbounds for h in 1:Nh, v in 1:Nv
-        fill!(slab(data, v, h), val)
-    end
-    return data
+function Base.fill!(dest::DataF, val, ::ToCPU)
+    @inbounds dest[] = val
+    return dest
 end

 Base.fill!(dest::AbstractData, val) =
diff --git a/src/DataLayouts/has_uniform_datalayouts.jl b/src/DataLayouts/has_uniform_datalayouts.jl
new file mode 100644
index 0000000000..1a919a9b0c
--- /dev/null
+++ b/src/DataLayouts/has_uniform_datalayouts.jl
@@ -0,0 +1,60 @@
+@inline function first_datalayout_in_bc(args::Tuple, rargs...)
+    x1 = first_datalayout_in_bc(args[1], rargs...)
+    x1 isa AbstractData && return x1
+    return first_datalayout_in_bc(Base.tail(args), rargs...)
+end
+
+@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
+    first_datalayout_in_bc(args[1], rargs...)
+@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
+@inline first_datalayout_in_bc(x) = nothing
+@inline first_datalayout_in_bc(x::AbstractData) = x
+
+@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
+    first_datalayout_in_bc(bc.args)
+
+@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
+    truesofar &&
+    _has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
+    _has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)
+
+@inline _has_uniform_datalayouts_args(
+    truesofar,
+    start,
+    args::Tuple{Any},
+    rargs...,
+) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
+@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
+    truesofar
+
+@inline function _has_uniform_datalayouts(
+    truesofar,
+    start,
+    bc::Base.Broadcast.Broadcasted,
+)
+    return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
+end
+for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
+    @eval begin
+        @inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
+    end
+end
+@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
+@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar
+
+"""
+    has_uniform_datalayouts(bc)
+
+Finds the first datalayout in the broadcast expression (BCE) and compares it
+against every other datalayout in the BCE. Returns
+ - `true` if the broadcasted object has only a single kind of datalayout (e.g. `VF, VF` or `VIJFH, VIJFH`)
+ - `false` if the broadcasted object has multiple kinds of datalayouts (e.g. `VIJFH, VIFH`)
+
+Note: a broadcasted object can have different _types_,
+    e.g., `VIJFH{Float64}` and `VIJFH{Tuple{Float64, Float64}}`,
+    but not different kinds, e.g., `VIJFH{Float64}` and `VF{Float64}`.
+"""
+function has_uniform_datalayouts end

+@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
+    _has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)
+
+@inline has_uniform_datalayouts(bc::AbstractData) = true
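+
+# Usage sketch (mirrors `test/DataLayouts/unit_has_uniform_datalayouts.jl`;
+# `data_VIFH`, `data_IJFH`, and `data_VF` are hypothetical fields of the
+# corresponding layouts):
+#
+#   bc = @lazy @. data_VIFH + data_VIFH
+#   has_uniform_datalayouts(bc)            # true: a single kind (VIFH)
+#
+#   bc = @lazy @. data_IJFH + data_VF
+#   has_uniform_datalayouts(bc)            # false: mixes IJFH and VF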
diff --git a/src/DataLayouts/non_extruded_broadcasted.jl b/src/DataLayouts/non_extruded_broadcasted.jl
new file mode 100644
index 0000000000..ce38728fe7
--- /dev/null
+++ b/src/DataLayouts/non_extruded_broadcasted.jl
@@ -0,0 +1,158 @@
+#! format: off
+# ============================================================ Adapted from Base.Broadcast (julia version 1.10.4)
+import Base.Broadcast: BroadcastStyle
+struct NonExtrudedBroadcasted{
+    Style <: Union{Nothing, BroadcastStyle},
+    Axes,
+    F,
+    Args <: Tuple,
+} <: Base.Broadcast.AbstractBroadcasted
+    style::Style
+    f::F
+    args::Args
+    axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `NonExtrudedBroadcasted`)
+
+    NonExtrudedBroadcasted(style::Union{Nothing, BroadcastStyle}, f::Tuple, args::Tuple) =
+        error() # disambiguation: tuple is not callable
+    function NonExtrudedBroadcasted(
+        style::Union{Nothing, BroadcastStyle},
+        f::F,
+        args::Tuple,
+        axes = nothing,
+    ) where {F}
+        # using Core.Typeof rather than F preserves inferrability when f is a type
+        return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(
+            style,
+            f,
+            args,
+            axes,
+        )
+    end
+    function NonExtrudedBroadcasted(f::F, args::Tuple, axes = nothing) where {F}
+        NonExtrudedBroadcasted(Base.Broadcast.combine_styles(args...)::BroadcastStyle, f, args, axes)
+    end
+    function NonExtrudedBroadcasted{Style}(f::F, args, axes = nothing) where {Style, F}
+        return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(
+            Style()::Style,
+            f,
+            args,
+            axes,
+        )
+    end
+    function NonExtrudedBroadcasted{Style, Axes, F, Args}(
+        f,
+        args,
+        axes,
+    ) where {Style, Axes, F, Args}
+        return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
+    end
+end
+
+@inline to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted) =
+    NonExtrudedBroadcasted(bc.style, bc.f, bc.args, bc.axes)
+@inline to_non_extruded_broadcasted(x) = x
+NonExtrudedBroadcasted(bc::Base.Broadcast.Broadcasted) = to_non_extruded_broadcasted(bc)
+
+@inline to_non_extruded_broadcasted_args(args::Tuple, inds...) = (
+    to_non_extruded_broadcasted(args[1], inds...),
+    to_non_extruded_broadcasted_args(Base.tail(args), inds...)...,
+)
+@inline to_non_extruded_broadcasted_args(args::Tuple{Any}, inds...) =
+    (to_non_extruded_broadcasted(args[1], inds...),)
+@inline to_non_extruded_broadcasted_args(args::Tuple{}, inds...) = ()
+
+@inline function to_non_extruded_broadcasted(bc::Base.Broadcast.Broadcasted, symb, axes)
+    Base.Broadcast.Broadcasted(
+        bc.f,
+        to_non_extruded_broadcasted_args(bc.args, symb, axes),
+        axes,
+    )
+end
+@inline to_non_extruded_broadcasted(x, symb, axes) = x
+
+@inline function Base.getindex(
+    bc::NonExtrudedBroadcasted,
+    I::Union{Integer, CartesianIndex},
+)
+    @boundscheck Base.checkbounds(bc, I) # TODO: verify that this bounds check is sufficient
+    @inbounds _broadcast_getindex(bc, I)
+end
+
+# --- here, we define our own bounds checks
+@inline Base.checkbounds(bc::NonExtrudedBroadcasted, I::Integer) =
+    # Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,)) # from Base
+    Base.checkbounds_indices(Bool, (Base.OneTo(n_dofs(bc)),), (I,)) || Base.throw_boundserror(bc, (I,))
+
+import StaticArrays
+to_tuple(t::Tuple) = t
+to_tuple(t::NTuple{N, <: Base.OneTo}) where {N} = map(x->x.stop, t)
+to_tuple(t::NTuple{N, <: StaticArrays.SOneTo}) where {N} = map(x->x.stop, t)
+n_dofs(bc) = prod(to_tuple(axes(bc)))
+# ---
+
+Base.@propagate_inbounds _broadcast_getindex(
+    A::Union{Ref, AbstractArray{<:Any, 0}, Number},
+    I::Integer,
+) = A[] # Scalar-likes can just ignore all indices
+Base.@propagate_inbounds _broadcast_getindex(
+    ::Ref{Type{T}},
+    I::Integer,
+) where {T} = T
+# Tuples are statically known to be singleton or vector-like
+Base.@propagate_inbounds _broadcast_getindex(A::Tuple{Any}, I::Integer) = A[1]
+Base.@propagate_inbounds _broadcast_getindex(A::Tuple, I::Integer) = A[I[1]]
+# Everything else falls back to dynamically dropping broadcasted indices based upon its axes
+# Base.@propagate_inbounds _broadcast_getindex(A, I) = A[newindex(A, I)]
+Base.@propagate_inbounds _broadcast_getindex(A, I::Integer) = A[I]
+Base.@propagate_inbounds function _broadcast_getindex(
+    bc::NonExtrudedBroadcasted{<:Any, <:Any, <:Any, <:Any},
+    I::Integer,
+)
+    args = _getindex(bc.args, I)
+    return _broadcast_getindex_evalf(bc.f, args...)
+end
+@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
+    f(args...) # not propagate_inbounds
+Base.@propagate_inbounds _getindex(args::Tuple, I) =
+    (_broadcast_getindex(args[1], I), _getindex(Base.tail(args), I)...)
+Base.@propagate_inbounds _getindex(args::Tuple{Any}, I) =
+    (_broadcast_getindex(args[1], I),)
+Base.@propagate_inbounds _getindex(args::Tuple{}, I) = ()
+
+@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
+_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
+@inline _axes(bc::NonExtrudedBroadcasted, ::Nothing) =
+    Base.Broadcast.combine_axes(bc.args...)
+_axes(bc::NonExtrudedBroadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}, ::Nothing) = ()
+@inline Base.axes(bc::NonExtrudedBroadcasted{<:Any, <:NTuple{N}}, d::Integer) where {N} =
+    d <= N ? axes(bc)[d] : Base.OneTo(1)
+Base.IndexStyle(::Type{<:NonExtrudedBroadcasted{<:Any, <:Tuple{Any}}}) = IndexLinear()
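+
+# Linear-indexing sketch (`data_a` and `data_b` are hypothetical datalayouts
+# of the same kind):
+#
+#   bc  = Base.Broadcast.broadcasted(+, data_a, data_b)
+#   nbc = to_non_extruded_broadcasted(bc)
+#   nbc[3]   # == data_a[3] + data_b[3]; every argument is addressed with the
+#            # same linear index, with no per-argument extrusion bookkeeping.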
+
+# ============================================================
+
+#! format: on
+# Datalayouts
+@propagate_inbounds function Base.getindex(
+    data::AbstractData{S},
+    I::Integer,
+) where {S}
+    s_array = farray_size(data)
+    ss = StaticSize(s_array, field_dim(data))
+    @inbounds get_struct_linear(parent(data), S, Val(field_dim(data)), I, ss)
+end
+@propagate_inbounds function Base.setindex!(
+    data::AbstractData{S},
+    val,
+    I::Integer,
+) where {S}
+    s_array = farray_size(data)
+    ss = StaticSize(s_array, field_dim(data))
+    @inbounds set_struct_linear!(
+        parent(data),
+        convert(S, val),
+        Val(field_dim(data)),
+        I,
+        ss,
+    )
+end
diff --git a/src/DataLayouts/struct.jl b/src/DataLayouts/struct.jl
index c20b580734..6e92616460 100644
--- a/src/DataLayouts/struct.jl
+++ b/src/DataLayouts/struct.jl
@@ -218,6 +218,153 @@ Base.@propagate_inbounds function get_struct(
     @inbounds return array[start_index]
 end

+
+abstract type _Size end
+struct DynamicSize <: _Size end
+struct StaticSize{S_array, FD} <: _Size
+    function StaticSize{S, FD}() where {S, FD}
+        new{S::Tuple{Vararg{Int}}, FD}()
+    end
+end
+
+Base.@pure StaticSize(s::Tuple{Vararg{Int}}, FD) = StaticSize{s, FD}()
+
+# Some @pure convenience functions for `StaticSize`
+s_field_dim_1(::Type{StaticSize{S, FD}}) where {S, FD} =
+    ntuple(j -> j == FD ? 1 : S[j], length(S))
+s_field_dim_1(::StaticSize{S, FD}) where {S, FD} =
+    ntuple(j -> j == FD ? 1 : S[j], length(S))
+
+Base.@pure get(::Type{StaticSize{S}}) where {S} = S
+Base.@pure get(::StaticSize{S}) where {S} = S
+Base.@pure Base.getindex(::StaticSize{S}, i::Int) where {S} =
+    i <= length(S) ? S[i] : 1
+Base.@pure Base.ndims(::StaticSize{S}) where {S} = length(S)
+Base.@pure Base.ndims(::Type{StaticSize{S}}) where {S} = length(S)
+Base.@pure Base.length(::StaticSize{S}) where {S} = prod(S)
+
+Base.@propagate_inbounds cart_inds(n::NTuple) =
+    @inbounds CartesianIndices(map(x -> Base.OneTo(x), n))
+Base.@propagate_inbounds linear_inds(n::NTuple) =
+    @inbounds LinearIndices(map(x -> Base.OneTo(x), n))
+
+include("to_linear_index.jl") # TODO: delete if not needed
+
+@inline function offset_index(
+    base_index::Integer,
+    start_index::Integer,
+    ::Val{D},
+    field_offset,
+    ss::StaticSize{SS},
+) where {D, SS}
+    # TODO: compute this offset directly without going through CartesianIndex
+    SS1 = s_field_dim_1(typeof(ss))
+    ci = cart_inds(SS1)[base_index]
+    ci_poff = CartesianIndex(
+        ntuple(n -> n == D ? ci[n] + field_offset : ci[n], ndims(ss)),
+    )
+    li = linear_inds(SS)[ci_poff]
+    return li
+end
+
+Base.@propagate_inbounds @generated function get_struct_linear(
+    array::AbstractArray{T},
+    ::Type{S},
+    ::Val{D},
+    start_index::Integer,
+    ss::StaticSize,
+    base_index = start_index,
+) where {T, S, D}
+    tup = :(())
+    for i in 1:fieldcount(S)
+        push!(
+            tup.args,
+            :(get_struct_linear(
+                array,
+                fieldtype(S, $i),
+                Val($D),
+                offset_index(
+                    base_index,
+                    start_index,
+                    Val($D),
+                    $(fieldtypeoffset(T, S, Val(i))),
+                    ss,
+                ),
+                ss,
+                base_index,
+            )),
+        )
+    end
+    return quote
+        Base.@_propagate_inbounds_meta
+        @inbounds bypass_constructor(S, $tup)
+    end
+end
+
+# recursion base case: the array eltype is the same as the struct leaf type
+Base.@propagate_inbounds function get_struct_linear(
+    array::AbstractArray{S},
+    ::Type{S},
+    ::Val{D},
+    start_index::Integer,
+    ss::StaticSize,
+    base_index = start_index,
+) where {S, D}
+    @inbounds return array[start_index]
+end
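+
+# Worked example (cf. `test/DataLayouts/unit_linear_indexing.jl`): for a
+# backing array of size (3, 2, 4), field dimension 2, and S = Foo{Float64}
+# (two components), linear dof index 4 maps to CartesianIndex(1, 1, 2) in the
+# field-collapsed size (3, 1, 4); offsetting along dimension 2 by each
+# field's `fieldtypeoffset` yields array indices (1, 1, 2) and (1, 2, 2),
+# i.e. Foo(a[7], a[10]) in column-major order.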
+""" +Base.@propagate_inbounds @generated function set_struct_linear!( + array::AbstractArray{T}, + val::S, + ::Val{D}, + start_index::Integer, + ss::StaticSize, + base_index = start_index, +) where {T, S, D} + ex = quote + Base.@_propagate_inbounds_meta + end + for i in 1:fieldcount(S) + push!( + ex.args, + :(set_struct_linear!( + array, + getfield(val, $i), + Val($D), + offset_index( + base_index, + start_index, + Val($D), + $(fieldtypeoffset(T, S, Val(i))), + ss, + ), + ss, + base_index, + )), + ) + end + push!(ex.args, :(return val)) + return ex +end + +Base.@propagate_inbounds function set_struct_linear!( + array::AbstractArray{S}, + val::S, + ::Val{D}, + start_index::Integer, + us::StaticSize, + base_index = start_index, +) where {S, D} + @inbounds array[start_index] = val + val +end + """ set_struct!(array, val::S, Val(D), start_index) diff --git a/src/DataLayouts/to_linear_index.jl b/src/DataLayouts/to_linear_index.jl new file mode 100644 index 0000000000..26cd52c49b --- /dev/null +++ b/src/DataLayouts/to_linear_index.jl @@ -0,0 +1,49 @@ +_to_linear_index(A::AbstractArray, li, ci) = + _to_linear_index(A, Base.to_indices(li, (ci,))...) +_to_linear_index(A::AbstractArray, I::Integer...) = (@inline; _sub2ind(A, I...)) + +function _sub2ind(A::AbstractArray, I...) + @inline + _sub2ind(axes(A), I...) +end + +# 0-dimensional arrays and indexing with [] +_sub2ind(::Tuple{}) = 1 +_sub2ind(::Base.DimsInteger) = 1 +# _sub2ind(::Indices) = 1 +_sub2ind(::Tuple{}, I::Integer...) = (@inline; _sub2ind_recurse((), 1, 1, I...)) + +# Generic cases +_sub2ind(dims::Base.DimsInteger, I::Integer...) = + (@inline; _sub2ind_recurse(dims, 1, 1, I...)) +_sub2ind(inds::Base.Indices, I::Integer...) = + (@inline; _sub2ind_recurse(inds, 1, 1, I...)) +# In 1d, there's a question of whether we're doing cartesian indexing +# or linear indexing. Support only the former. +_sub2ind(inds::Base.Indices{1}, I::Integer...) = throw( + ArgumentError("Linear indexing is not defined for one-dimensional arrays"), +) +_sub2ind(inds::Tuple{Base.OneTo}, I::Integer...) = + (@inline; _sub2ind_recurse(inds, 1, 1, I...)) # only OneTo is safe +_sub2ind(inds::Tuple{Base.OneTo}, i::Integer) = i + +_sub2ind_recurse(::Any, L, ind) = ind +function _sub2ind_recurse(::Tuple{}, L, ind, i::Integer, I::Integer...) + @inline + _sub2ind_recurse((), L, ind + (i - 1) * L, I...) +end +function _sub2ind_recurse(inds, L, ind, i::Integer, I::Integer...) 
+function _sub2ind_recurse(inds, L, ind, i::Integer, I::Integer...)
+    @inline
+    r1 = inds[1]
+    _sub2ind_recurse(
+        Base.tail(inds),
+        nextL(L, r1),
+        ind + offsetin(i, r1) * L,
+        I...,
+    )
+end
+
+nextL(L, l::Integer) = L * l
+nextL(L, r::AbstractUnitRange) = L * length(r)
+offsetin(i, l::Integer) = i - 1
+offsetin(i, r::AbstractUnitRange) = i - first(r)
diff --git a/test/DataLayouts/unit_copyto.jl b/test/DataLayouts/unit_copyto.jl
index 0b304a4f81..c9cb676505 100644
--- a/test/DataLayouts/unit_copyto.jl
+++ b/test/DataLayouts/unit_copyto.jl
@@ -1,5 +1,6 @@
 #=
-julia --project
+julia --check-bounds=yes --project
+ENV["CLIMACOMMS_DEVICE"] = "CPU";
 using Revise; include(joinpath("test", "DataLayouts", "unit_copyto.jl"))
 =#
 using Test
diff --git a/test/DataLayouts/unit_has_uniform_datalayouts.jl b/test/DataLayouts/unit_has_uniform_datalayouts.jl
new file mode 100644
index 0000000000..4735b065f1
--- /dev/null
+++ b/test/DataLayouts/unit_has_uniform_datalayouts.jl
@@ -0,0 +1,49 @@
+#=
+julia --project
+using Revise; include(joinpath("test", "DataLayouts", "unit_has_uniform_datalayouts.jl"))
+=#
+using Test
+using ClimaCore.DataLayouts
+import ClimaCore.Geometry
+import ClimaComms
+import LazyBroadcast: @lazy
+using StaticArrays
+import Random
+Random.seed!(1234)
+
+@testset "has_uniform_datalayouts" begin
+    device = ClimaComms.device()
+    device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
+    FT = Float64
+    S = FT
+    Nf = 1
+    Nv = 4
+    Nij = 3
+    Nh = 5
+    Nk = 6
+#! format: off
+    data_DataF = DataF{S}(device_zeros(FT,Nf));
+    data_IJFH = IJFH{S, Nij, Nh}(device_zeros(FT,Nij,Nij,Nf,Nh));
+    data_IFH = IFH{S, Nij, Nh}(device_zeros(FT,Nij,Nf,Nh));
+    data_IJF = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf));
+    data_IF = IF{S, Nij}(device_zeros(FT,Nij,Nf));
+    data_VF = VF{S, Nv}(device_zeros(FT,Nv,Nf));
+    data_VIJFH = VIJFH{S,Nv,Nij,Nh}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh));
+    data_VIFH = VIFH{S, Nv, Nij, Nh}(device_zeros(FT,Nv,Nij,Nf,Nh));
+#! format: on
+
+    bc = @lazy @. data_VIFH + data_VIFH
+    @test DataLayouts.has_uniform_datalayouts(bc)
+    bc = @lazy @. data_IJFH + data_VF
+    @test !DataLayouts.has_uniform_datalayouts(bc)
+
+    data_VIJFHᶜ = VIJFH{S, Nv, Nij, Nh}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh))
+    data_VIJFHᶠ =
+        VIJFH{S, Nv + 1, Nij, Nh}(device_zeros(FT, Nv + 1, Nij, Nij, Nf, Nh))
+
+    # This is not a valid broadcast expression,
+    # but these two datalayouts can exist in a
+    # valid broadcast expression (e.g., interpolation).
+    bc = @lazy @. data_VIJFHᶜ + data_VIJFHᶠ
+    @test DataLayouts.has_uniform_datalayouts(bc)
+end
diff --git a/test/DataLayouts/unit_linear_indexing.jl b/test/DataLayouts/unit_linear_indexing.jl
new file mode 100644
index 0000000000..8685f1605e
--- /dev/null
+++ b/test/DataLayouts/unit_linear_indexing.jl
@@ -0,0 +1,191 @@
+#=
+julia --check-bounds=yes --project
+using Revise; include(joinpath("test", "DataLayouts", "unit_linear_indexing.jl"))
+=#
+using Test
+using ClimaCore.DataLayouts
+using ClimaCore.DataLayouts: get_struct_linear
+import ClimaCore.DataLayouts as DL
+import ClimaCore.Geometry
+# import ClimaComms
+using StaticArrays
+# ClimaComms.@import_required_backends
+import Random
+Random.seed!(1234)
+
+# Per-field linear offsets for each field of `S` (helper, currently unused
+# in the tests below).
+offset_indices(
+    ::Type{FT},
+    ::Type{S},
+    ::Val{D},
+    start_index::Integer,
+    ss::DataLayouts.StaticSize,
+) where {FT, S, D} = map(
+    i -> DL.offset_index(
+        start_index,
+        start_index,
+        Val(D),
+        DL.fieldtypeoffset(FT, S, Val(i)),
+        ss,
+    ),
+    1:fieldcount(S),
+)
+
+field_dim_to_one(s, dim) = Tuple(map(j -> j == dim ? 1 : s[j], 1:length(s)))
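+
+# `field_dim_to_one` collapses the field dimension of an array size to 1,
+# yielding the shape over which linear dof indices run. E.g., for the
+# IFH-like size (3, 2, 4) with field dimension 2 it returns (3, 1, 4), so
+# dof indices 1:12 sweep the (i, h) pairs in column-major order.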
+
+Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) =
+    @inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i]
+Base.@propagate_inbounds linear_ind(n::NTuple, ci::CartesianIndex) =
+    @inbounds LinearIndices(map(x -> Base.OneTo(x), n))[ci]
+Base.@propagate_inbounds linear_ind(n::NTuple, loc::NTuple) =
+    linear_ind(n, CartesianIndex(loc))
+
+function debug_get_struct_linear(args...; expect_test_throws = false)
+    if expect_test_throws
+        get_struct_linear(args...)
+    else
+        try
+            get_struct_linear(args...)
+        catch
+            # rerun on failure so the error surfaces under the debugger
+            get_struct_linear(args...)
+        end
+    end
+end
+
+function one_to_n(a::Array)
+    for i in 1:length(a)
+        a[i] = i
+    end
+    return a
+end
+one_to_n(s::Tuple, ::Type{FT}) where {FT} = one_to_n(zeros(FT, s...))
+ncomponents(::Type{FT}, ::Type{S}) where {FT, S} = div(sizeof(S), sizeof(FT))
+
+struct Foo{T}
+    x::T
+    y::T
+end
+
+Base.zero(::Type{Foo{T}}) where {T} = Foo{T}(0, 0)
+
+@testset "get_struct - IFH indexing (float)" begin
+    FT = Float64
+    S = FT
+    s_array = (3, 1, 4)
+    @test ncomponents(FT, S) == 1
+    a = one_to_n(s_array, FT)
+    ss = DataLayouts.StaticSize(s_array, 2)
+    @test debug_get_struct_linear(a, S, Val(2), 1, ss) == 1.0
+    @test debug_get_struct_linear(a, S, Val(2), 2, ss) == 2.0
+    @test debug_get_struct_linear(a, S, Val(2), 3, ss) == 3.0
+    @test debug_get_struct_linear(a, S, Val(2), 4, ss) == 4.0
+    @test debug_get_struct_linear(a, S, Val(2), 5, ss) == 5.0
+    @test debug_get_struct_linear(a, S, Val(2), 6, ss) == 6.0
+    @test debug_get_struct_linear(a, S, Val(2), 7, ss) == 7.0
+    @test debug_get_struct_linear(a, S, Val(2), 8, ss) == 8.0
+    @test debug_get_struct_linear(a, S, Val(2), 9, ss) == 9.0
+    @test debug_get_struct_linear(a, S, Val(2), 10, ss) == 10.0
+    @test debug_get_struct_linear(a, S, Val(2), 11, ss) == 11.0
+    @test debug_get_struct_linear(a, S, Val(2), 12, ss) == 12.0
+    @test_throws BoundsError debug_get_struct_linear(
+        a,
+        S,
+        Val(2),
+        13,
+        ss;
+        expect_test_throws = true,
+    )
+end
+
+@testset "get_struct - IFH indexing" begin
+    FT = Float64
+    S = Foo{FT}
+    s_array = (3, 2, 4)
+    @test ncomponents(FT, S) == 2
+    a = one_to_n(s_array, FT)
+    ss = DataLayouts.StaticSize(s_array, 2)
+    @test debug_get_struct_linear(a, S, Val(2), 1, ss) == Foo{FT}(1.0, 4.0)
+    @test debug_get_struct_linear(a, S, Val(2), 2, ss) == Foo{FT}(2.0, 5.0)
+    @test debug_get_struct_linear(a, S, Val(2), 3, ss) == Foo{FT}(3.0, 6.0)
+    @test debug_get_struct_linear(a, S, Val(2), 4, ss) == Foo{FT}(7.0, 10.0)
+    @test debug_get_struct_linear(a, S, Val(2), 5, ss) == Foo{FT}(8.0, 11.0)
+    @test debug_get_struct_linear(a, S, Val(2), 6, ss) == Foo{FT}(9.0, 12.0)
+    @test debug_get_struct_linear(a, S, Val(2), 7, ss) == Foo{FT}(13.0, 16.0)
+    @test debug_get_struct_linear(a, S, Val(2), 8, ss) == Foo{FT}(14.0, 17.0)
+    @test debug_get_struct_linear(a, S, Val(2), 9, ss) == Foo{FT}(15.0, 18.0)
+    @test debug_get_struct_linear(a, S, Val(2), 10, ss) == Foo{FT}(19.0, 22.0)
+    @test debug_get_struct_linear(a, S, Val(2), 11, ss) == Foo{FT}(20.0, 23.0)
+    @test debug_get_struct_linear(a, S, Val(2), 12, ss) == Foo{FT}(21.0, 24.0)
+    @test_throws BoundsError debug_get_struct_linear(
+        a,
+        S,
+        Val(2),
+        13,
+        ss;
+        expect_test_throws = true,
+    )
+end
+
+@testset "get_struct - IJF indexing" begin
+    FT = Float64
+    S = Foo{FT}
+    s_array = (3, 4, 2)
+    @test ncomponents(FT, S) == 2
+    s = field_dim_to_one(s_array, 3)
+    a = one_to_n(s_array, FT)
+    ss = DataLayouts.StaticSize(s_array, 3)
+    @test debug_get_struct_linear(a, S, Val(3), 1, ss) == Foo{FT}(1.0, 13.0)
+    @test debug_get_struct_linear(a, S, Val(3), 2, ss) == Foo{FT}(2.0, 14.0)
+    @test debug_get_struct_linear(a, S, Val(3), 3, ss) == Foo{FT}(3.0, 15.0)
+    @test debug_get_struct_linear(a, S, Val(3), 4, ss) == Foo{FT}(4.0, 16.0)
+    @test debug_get_struct_linear(a, S, Val(3), 5, ss) == Foo{FT}(5.0, 17.0)
+    @test debug_get_struct_linear(a, S, Val(3), 6, ss) == Foo{FT}(6.0, 18.0)
+    @test debug_get_struct_linear(a, S, Val(3), 7, ss) == Foo{FT}(7.0, 19.0)
+    @test debug_get_struct_linear(a, S, Val(3), 8, ss) == Foo{FT}(8.0, 20.0)
+    @test debug_get_struct_linear(a, S, Val(3), 9, ss) == Foo{FT}(9.0, 21.0)
+    @test debug_get_struct_linear(a, S, Val(3), 10, ss) == Foo{FT}(10.0, 22.0)
+    @test debug_get_struct_linear(a, S, Val(3), 11, ss) == Foo{FT}(11.0, 23.0)
+    @test debug_get_struct_linear(a, S, Val(3), 12, ss) == Foo{FT}(12.0, 24.0)
+    @test_throws BoundsError debug_get_struct_linear(
+        a,
+        S,
+        Val(3),
+        13,
+        ss;
+        expect_test_throws = true,
+    )
+end
+
+@testset "get_struct - VIJFH indexing" begin
+    FT = Float64
+    S = Foo{FT}
+    s_array = (2, 2, 2, 2, 2)
+    @test ncomponents(FT, S) == 2
+    s = field_dim_to_one(s_array, 4)
+    a = one_to_n(s_array, FT)
+    ss = DataLayouts.StaticSize(s_array, 4)
+
+    @test debug_get_struct_linear(a, S, Val(4), 1, ss) == Foo{FT}(1.0, 9.0)
+    @test debug_get_struct_linear(a, S, Val(4), 2, ss) == Foo{FT}(2.0, 10.0)
+    @test debug_get_struct_linear(a, S, Val(4), 3, ss) == Foo{FT}(3.0, 11.0)
+    @test debug_get_struct_linear(a, S, Val(4), 4, ss) == Foo{FT}(4.0, 12.0)
+    @test debug_get_struct_linear(a, S, Val(4), 5, ss) == Foo{FT}(5.0, 13.0)
+    @test debug_get_struct_linear(a, S, Val(4), 6, ss) == Foo{FT}(6.0, 14.0)
+    @test debug_get_struct_linear(a, S, Val(4), 7, ss) == Foo{FT}(7.0, 15.0)
+    @test debug_get_struct_linear(a, S, Val(4), 8, ss) == Foo{FT}(8.0, 16.0)
+    @test debug_get_struct_linear(a, S, Val(4), 9, ss) == Foo{FT}(17.0, 25.0)
+    @test debug_get_struct_linear(a, S, Val(4), 10, ss) == Foo{FT}(18.0, 26.0)
+    @test debug_get_struct_linear(a, S, Val(4), 11, ss) == Foo{FT}(19.0, 27.0)
+    @test debug_get_struct_linear(a, S, Val(4), 12, ss) == Foo{FT}(20.0, 28.0)
+    @test debug_get_struct_linear(a, S, Val(4), 13, ss) == Foo{FT}(21.0, 29.0)
+    @test debug_get_struct_linear(a, S, Val(4), 14, ss) == Foo{FT}(22.0, 30.0)
+    @test debug_get_struct_linear(a, S, Val(4), 15, ss) == Foo{FT}(23.0, 31.0)
+    @test debug_get_struct_linear(a, S, Val(4), 16, ss) == Foo{FT}(24.0, 32.0)
+    @test_throws BoundsError debug_get_struct_linear(
+        a,
+        S,
+        Val(4),
+        17,
+        ss;
+        expect_test_throws = true,
+    )
+end
+
+# TODO: add set_struct! tests
diff --git a/test/runtests.jl b/test/runtests.jl
index f8540cb4e5..4d35eaebb6 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,6 +9,7 @@ include("tabulated_tests.jl")
 unit_tests = [
 UnitTest("DataLayouts fill" ,"DataLayouts/unit_fill.jl"),
 UnitTest("DataLayouts ndims" ,"DataLayouts/unit_ndims.jl"),
+UnitTest("DataLayouts has_uniform_datalayouts" ,"DataLayouts/unit_has_uniform_datalayouts.jl"),
 UnitTest("DataLayouts get_struct" ,"DataLayouts/unit_struct.jl"),
 UnitTest("Recursive" ,"RecursiveApply/unit_recursive_apply.jl"),
 UnitTest("PlusHalf" ,"Utilities/unit_plushalf.jl"),