diff --git a/ext/ClimaCoreCUDAExt.jl b/ext/ClimaCoreCUDAExt.jl index 665da8fbec..8ff9441c02 100644 --- a/ext/ClimaCoreCUDAExt.jl +++ b/ext/ClimaCoreCUDAExt.jl @@ -16,9 +16,6 @@ import ClimaCore.Utilities: cart_ind, linear_ind import ClimaCore.RecursiveApply: ⊠, ⊞, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax -const CuArrayBackedTypes = - Union{CUDA.CuArray, SubArray{<:Any, <:Any, <:CUDA.CuArray}} - include(joinpath("cuda", "cuda_utils.jl")) include(joinpath("cuda", "data_layouts.jl")) include(joinpath("cuda", "fields.jl")) diff --git a/ext/cuda/data_layouts.jl b/ext/cuda/data_layouts.jl index aa1e0f7ef7..1e786c5788 100644 --- a/ext/cuda/data_layouts.jl +++ b/ext/cuda/data_layouts.jl @@ -73,7 +73,8 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, - dest1::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes}, + dest1::VIJFH{S, Nv, Nij}, + ::CUDA.CuArray, ) where {S, Nv, Nij} _, _, _, _, Nh = size(dest1) if Nv > 0 && Nh > 0 @@ -92,7 +93,8 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, - dest1::IJFH{S, Nij, <:CuArrayBackedTypes}, + dest1::IJFH{S, Nij}, + ::CUDA.CuArray, ) where {S, Nij} _, _, _, _, Nh = size(dest1) if Nh > 0 @@ -108,7 +110,8 @@ function fused_copyto!( end function fused_copyto!( fmbc::FusedMultiBroadcast, - dest1::VF{S, Nv, <:CuArrayBackedTypes}, + dest1::VF{S, Nv}, + ::CUDA.CuArray, ) where {S, Nv} _, _, _, _, Nh = size(dest1) if Nv > 0 && Nh > 0 @@ -125,7 +128,8 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, - dest1::DataF{S, <:CuArrayBackedTypes}, + dest1::DataF{S}, + ::CUDA.CuArray, ) where {S} auto_launch!( knl_fused_copyto!, diff --git a/ext/cuda/data_layouts_copyto.jl b/ext/cuda/data_layouts_copyto.jl index da40d8cb3f..46188ea078 100644 --- a/ext/cuda/data_layouts_copyto.jl +++ b/ext/cuda/data_layouts_copyto.jl @@ -1,3 +1,5 @@ +DataLayouts._backed_array(x::CUDA.CuArray) = x + function knl_copyto!(dest, src) i = CUDA.threadIdx().x @@ -15,8 +17,9 @@ end function Base.copyto!( dest::IJFH{S, Nij}, - bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, A}, -) where {S, Nij, A <: CuArrayBackedTypes} + bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, + ::CUDA.CuArray, +) where {S, Nij} _, _, _, _, Nh = size(bc) if Nh > 0 auto_launch!( @@ -32,8 +35,9 @@ end function Base.copyto!( dest::VIJFH{S, Nv, Nij}, - bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, A}, -) where {S, Nv, Nij, A <: CuArrayBackedTypes} + bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, + ::CUDA.CuArray, +) where {S, Nv, Nij} _, _, _, _, Nh = size(bc) if Nv > 0 && Nh > 0 Nv_per_block = min(Nv, fld(256, Nij * Nij)) @@ -49,27 +53,11 @@ function Base.copyto!( return dest end -function Base.copyto!( - dest::VF{S, Nv}, - bc::DataLayouts.BroadcastedUnionVF{S, Nv, A}, -) where {S, Nv, A <: CuArrayBackedTypes} - _, _, _, _, Nh = size(dest) - if Nv > 0 && Nh > 0 - auto_launch!( - knl_copyto!, - (dest, bc), - dest; - threads_s = (1, 1), - blocks_s = (Nh, Nv), - ) - end - return dest -end - function Base.copyto!( dest::DataF{S}, - bc::DataLayouts.BroadcastedUnionDataF{S, A}, -) where {S, A <: CUDA.CuArray} + bc::DataLayouts.BroadcastedUnionDataF{S}, + ::CUDA.CuArray, +) where {S} auto_launch!( knl_copyto!, (dest, bc), @@ -104,12 +92,12 @@ end # TODO: can we use CUDA's luanch configuration for all data layouts? # Currently, it seems to have a slight performance degredation. #! format: off -# Base.copyto!(dest::IJFH{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc) -Base.copyto!(dest::IFH{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc) -Base.copyto!(dest::IJF{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc) -Base.copyto!(dest::IF{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIF{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc) -# Base.copyto!(dest::VIFH{S, Nv, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, <:CuArrayBackedTypes}) where {S, Nv, Ni} = cuda_copyto!(dest, bc) -# Base.copyto!(dest::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, <:CuArrayBackedTypes}) where {S, Nv, Nij} = cuda_copyto!(dest, bc) -# Base.copyto!(dest::VF{S, Nv, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVF{S, Nv, <:CuArrayBackedTypes}) where {S, Nv} = cuda_copyto!(dest, bc) -# Base.copyto!(dest::DataF{S, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionDataF{S, <:CuArrayBackedTypes}) where {S} = cuda_copyto!(dest, bc) +# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc) +Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc) +Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc) +Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc) +# Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::CUDA.CuArray) where {S, Nv, Ni} = cuda_copyto!(dest, bc) +# Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::CUDA.CuArray) where {S, Nv, Nij} = cuda_copyto!(dest, bc) +Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::CUDA.CuArray) where {S, Nv} = cuda_copyto!(dest, bc) +# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::CUDA.CuArray) where {S} = cuda_copyto!(dest, bc) #! format: on diff --git a/ext/cuda/data_layouts_fill.jl b/ext/cuda/data_layouts_fill.jl index 963ec2be86..3e9dfbef07 100644 --- a/ext/cuda/data_layouts_fill.jl +++ b/ext/cuda/data_layouts_fill.jl @@ -19,12 +19,12 @@ function cuda_fill!(dest::AbstractData, val) end #! format: off -Base.fill!(dest::IJFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::IFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::IJF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::IF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::VIFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::VF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) -Base.fill!(dest::DataF{<:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val) +Base.fill!(dest::IJFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::IFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::IJF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::IF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::VIFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::VF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) +Base.fill!(dest::DataF{<:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val) #! format: on diff --git a/src/DataLayouts/DataLayouts.jl b/src/DataLayouts/DataLayouts.jl index 758ccd5725..b50e630189 100644 --- a/src/DataLayouts/DataLayouts.jl +++ b/src/DataLayouts/DataLayouts.jl @@ -325,7 +325,7 @@ function Base.size(data::IJFH{S, Nij}) where {S, Nij} (Nij, Nij, 1, Nv, Nh) end -function Base.fill!(data::IJFH, val) +function Base.fill!(data::IJFH, val, ::Array) (_, _, _, _, Nh) = size(data) @inbounds for h in 1:Nh fill!(slab(data, h), val) @@ -494,7 +494,7 @@ function Base.size(data::IFH{S, Ni}) where {S, Ni} (Ni, 1, 1, Nv, Nh) end -function Base.fill!(data::IFH, val) +function Base.fill!(data::IFH, val, ::Array) (_, _, _, _, Nh) = size(data) @inbounds for h in 1:Nh fill!(slab(data, h), val) @@ -618,7 +618,7 @@ function DataF(x::T) where {T} end -function Base.fill!(data::DataF, val) +function Base.fill!(data::DataF, val, ::Array) @inbounds data[] = val return data end @@ -746,7 +746,7 @@ end function Base.size(data::IJF{S, Nij}) where {S, Nij} return (Nij, Nij, 1, 1, 1) end -function Base.fill!(data::IJF{S, Nij}, val) where {S, Nij} +function Base.fill!(data::IJF{S, Nij}, val, ::Array) where {S, Nij} @inbounds for j in 1:Nij, i in 1:Nij data[i, j] = val end @@ -884,7 +884,7 @@ function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T} return IF{S′, Ni}(similar(array, T)) end -function Base.fill!(data::IF{S, Ni}, val) where {S, Ni} +function Base.fill!(data::IF{S, Ni}, val, ::Array) where {S, Ni} @inbounds for i in 1:Ni data[i] = val end @@ -998,7 +998,7 @@ Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, 1, Nv, 1) nlevels(::VF{S, Nv}) where {S, Nv} = Nv -function Base.fill!(data::VF, val) +function Base.fill!(data::VF, val, ::Array) Nv = nlevels(data) @inbounds for v in 1:Nv data[v] = val @@ -1123,7 +1123,7 @@ function Base.length(data::VIJFH) size(parent(data), 1) * size(parent(data), 5) end -function Base.fill!(data::VIJFH, val) +function Base.fill!(data::VIJFH, val, ::Array) (Ni, Nj, _, Nv, Nh) = size(data) @inbounds for h in 1:Nh, v in 1:Nv fill!(slab(data, v, h), val) @@ -1290,7 +1290,7 @@ end function Base.length(data::VIFH) nlevels(data) * size(parent(data), 4) end -function Base.fill!(data::VIFH, val) +function Base.fill!(data::VIFH, val, ::Array) (Ni, _, _, Nv, Nh) = size(data) @inbounds for h in 1:Nh, v in 1:Nv fill!(slab(data, v, h), val) @@ -1610,4 +1610,19 @@ array2data( ::VIJFH{<:Any, Nv, Nij}, ) where {T, Nv, Nij} = VIJFH{T, Nv, Nij}(reshape(array, Nv, Nij, Nij, 1, :)) +""" + backed_array(data::AbstractData) + +Returns an `Array` or a `CUDA.CuArray` depending +on how `data` is backed. +""" +backed_array(dest::AbstractData) = _backed_array(dest) + +_backed_array(x::Array) = x +_backed_array(x::SubArray) = _backed_array(parent(x)) +_backed_array(x::Base.ReshapedArray) = _backed_array(parent(x)) +_backed_array(x::AbstractData) = _backed_array(parent(x)) + +Base.fill!(dest::AbstractData, val) = Base.fill!(dest, val, backed_array(dest)) + end # module diff --git a/src/DataLayouts/broadcast.jl b/src/DataLayouts/broadcast.jl index c8cceda8b7..7b6bddef20 100644 --- a/src/DataLayouts/broadcast.jl +++ b/src/DataLayouts/broadcast.jl @@ -481,11 +481,18 @@ function Base.mapreduce( end end +Base.copyto!( + dest::AbstractData, + bc::Union{AbstractData, Base.Broadcast.Broadcasted}, +) = Base.copyto!(dest, bc, backed_array(dest)) + # broadcasting scalar assignment # Performance optimization for the common identity scalar case: dest .= val + function Base.copyto!( dest::AbstractData, bc::Base.Broadcast.Broadcasted{Style}, + ::Array, ) where { Style <: Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}}, @@ -500,6 +507,7 @@ end function Base.copyto!( dest::DataF{S}, bc::BroadcastedUnionDataF{S, A}, + ::Array, ) where {S, A} @inbounds dest[] = convert(S, bc[]) return dest @@ -508,6 +516,7 @@ end function Base.copyto!( dest::IJFH{S, Nij}, bc::BroadcastedUnionIJFH{S, Nij}, + ::Array, ) where {S, Nij} _, _, _, _, Nh = size(bc) @inbounds for h in 1:Nh @@ -521,6 +530,7 @@ end function Base.copyto!( dest::IFH{S, Ni}, bc::BroadcastedUnionIFH{S, Ni}, + ::Array, ) where {S, Ni} _, _, _, _, Nh = size(bc) @inbounds for h in 1:Nh @@ -535,6 +545,7 @@ end function Base.copyto!( dest::IJF{S, Nij}, bc::BroadcastedUnionIJF{S, Nij, A}, + ::Array, ) where {S, Nij, A} @inbounds for j in 1:Nij, i in 1:Nij idx = CartesianIndex(i, j, 1, 1, 1) @@ -546,6 +557,7 @@ end function Base.copyto!( dest::IF{S, Ni}, bc::BroadcastedUnionIF{S, Ni, A}, + ::Array, ) where {S, Ni, A} @inbounds for i in 1:Ni idx = CartesianIndex(i, 1, 1, 1, 1) @@ -558,6 +570,7 @@ end function Base.copyto!( dest::IF{S, Ni}, bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}}, + ::Array, ) where {S, Ni, A} @inbounds for i in 1:Ni idx = CartesianIndex(i, 1, 1, 1, 1) @@ -570,6 +583,7 @@ end function Base.copyto!( dest::VF{S, Nv}, bc::BroadcastedUnionVF{S, Nv, A}, + ::Array, ) where {S, Nv, A} @inbounds for v in 1:Nv idx = CartesianIndex(1, 1, 1, v, 1) @@ -581,6 +595,7 @@ end function Base.copyto!( dest::VIFH{S, Nv, Ni}, bc::BroadcastedUnionVIFH{S, Nv, Ni}, + ::Array, ) where {S, Nv, Ni} (_, _, _, _, Nh) = size(bc) # copy contiguous columns @@ -595,6 +610,7 @@ end function Base.copyto!( dest::VIJFH{S, Nv, Nij}, bc::BroadcastedUnionVIJFH{S, Nv, Nij}, + ::Array, ) where {S, Nv, Nij} # copy contiguous columns _, _, _, _, Nh = size(dest) @@ -635,12 +651,13 @@ function Base.copyto!( end, ) # check_fused_broadcast_axes(fmbc) # we should already have checked the axes - fused_copyto!(fmb_inst, dest1) + fused_copyto!(fmb_inst, dest1, backed_array(dest1)) end function fused_copyto!( fmbc::FusedMultiBroadcast, dest1::VIJFH{S1, Nv1, Nij}, + ::Array, ) where {S1, Nv1, Nij} _, _, _, _, Nh = size(dest1) for (dest, bc) in fmbc.pairs @@ -657,6 +674,7 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, dest1::IJFH{S, Nij}, + ::Array, ) where {S, Nij} # copy contiguous columns _, _, _, Nv, Nh = size(dest1) @@ -673,6 +691,7 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, dest1::VIFH{S, Nv1, Ni}, + ::Array, ) where {S, Nv1, Ni} # copy contiguous columns _, _, _, _, Nh = size(dest1) @@ -689,6 +708,7 @@ end function fused_copyto!( fmbc::FusedMultiBroadcast, dest1::VF{S1, Nv1}, + ::Array, ) where {S1, Nv1} for (dest, bc) in fmbc.pairs @inbounds for v in 1:Nv1 @@ -700,7 +720,11 @@ function fused_copyto!( return nothing end -function fused_copyto!(fmbc::FusedMultiBroadcast, dest::DataF{S}) where {S} +function fused_copyto!( + fmbc::FusedMultiBroadcast, + dest::DataF{S}, + ::Array, +) where {S} for (dest, bc) in fmbc.pairs @inbounds dest[] = convert(S, bc[]) end diff --git a/src/Fields/broadcast.jl b/src/Fields/broadcast.jl index b5c3f0cfcd..b7ce970794 100644 --- a/src/Fields/broadcast.jl +++ b/src/Fields/broadcast.jl @@ -129,7 +129,7 @@ function todata(bc::Base.Broadcast.Broadcasted{FieldStyle{DS}}) where {DS} Base.Broadcast.Broadcasted{DS}(bc.f, _args) end -# same logic as Base.Broadcasted (which only defines it for Tuples) +# same logic as Base.Broadcast.Broadcasted (which only defines it for Tuples) Base.axes(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) = _axes(bc, bc.axes) _axes(bc, ::Nothing) = Base.Broadcast.combine_axes(bc.args...) diff --git a/test/DataLayouts/unit_fill.jl b/test/DataLayouts/unit_fill.jl index a25872e90f..c9a0c63141 100644 --- a/test/DataLayouts/unit_fill.jl +++ b/test/DataLayouts/unit_fill.jl @@ -98,3 +98,55 @@ end # data = DataLayouts.IJKFVH{S, Nij, Nk}(device_zeros(FT,Nij,Nij,Nk,Nf,Nv,Nh)); test_fill!(data, (2,3)) # TODO: test # data = DataLayouts.IH1JH2{S, Nij}(device_zeros(FT,2*Nij,3*Nij)); test_fill!(data, (2,3)) # TODO: test end + +@testset "Reshaped Arrays" begin + device = ClimaComms.device() + device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...)) + function reshaped_array(data2) + # `reshape` does not always return a `ReshapedArray`, which + # we need to specialize on to correctly dispatch when its + # parent array is backed by a CuArray. So, let's first + # In order to get a ReshapedArray back, let's first create view + # via `data.:2`. This doesn't guarantee that the result is a + # ReshapedArray, but it works for several cases. Tests when + # are commented out for cases when Julia Base manages to return + # a parent-similar array. + data = data.:2 + array₀ = DataLayouts.data2array(data) + @test typeof(array₀) <: Base.ReshapedArray + rdata = DataLayouts.array2data(array₀, data) + newdata = DataLayouts.rebuild( + data, + SubArray( + parent(rdata), + ntuple(i -> Base.OneTo(size(parent(rdata), i)), ndims(rdata)), + ), + ) + rarray = parent(parent(newdata)) + @test typeof(rarray) <: Base.ReshapedArray + subarray = parent(rarray) + @test typeof(subarray) <: Base.SubArray + array = parent(subarray) + newdata + end + FT = Float64 + S = Tuple{FT, FT} # need at least 2 components to make a SubArray + Nf = 2 + Nv = 4 + Nij = 3 + Nh = 5 + Nk = 6 + # directly so that we can easily test all cases: +#! format: off + data = IJFH{S, Nij}(device_zeros(FT,Nij,Nij,Nf,Nh)); test_fill!(reshaped_array(data), 2) + data = IFH{S, Nij}(device_zeros(FT,Nij,Nf,Nh)); test_fill!(reshaped_array(data), 2) + # data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); test_fill!(reshaped_array(data), 2) + # data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); test_fill!(reshaped_array(data), 2) + # data = VF{S, Nv}(device_zeros(FT,Nv,Nf)); test_fill!(reshaped_array(data), 2) + data = VIJFH{S, Nv, Nij}(device_zeros(FT,Nv,Nij,Nij,Nf,Nh)); test_fill!(reshaped_array(data), 2) + data = VIFH{S, Nv, Nij}(device_zeros(FT,Nv,Nij,Nf,Nh)); test_fill!(reshaped_array(data), 2) +#! format: on + # TODO: test this + # data = DataLayouts.IJKFVH{S, Nij, Nk}(device_zeros(FT,Nij,Nij,Nk,Nf,Nv,Nh)); test_fill!(reshaped_array(data), 2) # TODO: test + # data = DataLayouts.IH1JH2{S, Nij}(device_zeros(FT,2*Nij,3*Nij)); test_fill!(reshaped_array(data), 2) # TODO: test +end