Skip to content

Commit

Permalink
CPU-GPU specialization -> dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 1, 2024
1 parent 4d9e9cd commit 9667382
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 57 deletions.
3 changes: 0 additions & 3 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax

const CuArrayBackedTypes =
Union{CUDA.CuArray, SubArray{<:Any, <:Any, <:CUDA.CuArray}}

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
include(joinpath("cuda", "fields.jl"))
Expand Down
12 changes: 8 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes},
dest1::VIJFH{S, Nv, Nij},
::CUDA.CuArray,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -92,7 +93,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij, <:CuArrayBackedTypes},
dest1::IJFH{S, Nij},
::CUDA.CuArray,
) where {S, Nij}
_, _, _, _, Nh = size(dest1)
if Nh > 0
Expand All @@ -108,7 +110,8 @@ function fused_copyto!(
end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S, Nv, <:CuArrayBackedTypes},
dest1::VF{S, Nv},
::CUDA.CuArray,
) where {S, Nv}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -125,7 +128,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::DataF{S, <:CuArrayBackedTypes},
dest1::DataF{S},
::CUDA.CuArray,
) where {S}
auto_launch!(
knl_fused_copyto!,
Expand Down
50 changes: 19 additions & 31 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
DataLayouts._backed_array(x::CUDA.CuArray) = x

function knl_copyto!(dest, src)

i = CUDA.threadIdx().x
Expand All @@ -15,8 +17,9 @@ end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, A},
) where {S, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij},
::CUDA.CuArray,
) where {S, Nij}
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand All @@ -32,8 +35,9 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, A},
) where {S, Nv, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij},
::CUDA.CuArray,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand All @@ -49,27 +53,11 @@ function Base.copyto!(
return dest
end

function Base.copyto!(
dest::VF{S, Nv},
bc::DataLayouts.BroadcastedUnionVF{S, Nv, A},
) where {S, Nv, A <: CuArrayBackedTypes}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (Nh, Nv),
)
end
return dest
end

function Base.copyto!(
dest::DataF{S},
bc::DataLayouts.BroadcastedUnionDataF{S, A},
) where {S, A <: CUDA.CuArray}
bc::DataLayouts.BroadcastedUnionDataF{S},
::CUDA.CuArray,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
Expand Down Expand Up @@ -104,12 +92,12 @@ end
# TODO: can we use CUDA's luanch configuration for all data layouts?
# Currently, it seems to have a slight performance degredation.
#! format: off
# Base.copyto!(dest::IJFH{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIF{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, <:CuArrayBackedTypes}) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, <:CuArrayBackedTypes}) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VF{S, Nv, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVF{S, Nv, <:CuArrayBackedTypes}) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionDataF{S, <:CuArrayBackedTypes}) where {S} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::CUDA.CuArray) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::CUDA.CuArray) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::CUDA.CuArray) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::CUDA.CuArray) where {S} = cuda_copyto!(dest, bc)
#! format: on
16 changes: 8 additions & 8 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ function cuda_fill!(dest::AbstractData, val)
end

#! format: off
Base.fill!(dest::IJFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
#! format: on
31 changes: 23 additions & 8 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function Base.size(data::IJFH{S, Nij}) where {S, Nij}
(Nij, Nij, 1, Nv, Nh)
end

function Base.fill!(data::IJFH, val)
function Base.fill!(data::IJFH, val, ::Array)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -494,7 +494,7 @@ function Base.size(data::IFH{S, Ni}) where {S, Ni}
(Ni, 1, 1, Nv, Nh)
end

function Base.fill!(data::IFH, val)
function Base.fill!(data::IFH, val, ::Array)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -618,7 +618,7 @@ function DataF(x::T) where {T}
end


function Base.fill!(data::DataF, val)
function Base.fill!(data::DataF, val, ::Array)
@inbounds data[] = val
return data
end
Expand Down Expand Up @@ -746,7 +746,7 @@ end
function Base.size(data::IJF{S, Nij}) where {S, Nij}
return (Nij, Nij, 1, 1, 1)
end
function Base.fill!(data::IJF{S, Nij}, val) where {S, Nij}
function Base.fill!(data::IJF{S, Nij}, val, ::Array) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[i, j] = val
end
Expand Down Expand Up @@ -884,7 +884,7 @@ function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T}
return IF{S′, Ni}(similar(array, T))
end

function Base.fill!(data::IF{S, Ni}, val) where {S, Ni}
function Base.fill!(data::IF{S, Ni}, val, ::Array) where {S, Ni}
@inbounds for i in 1:Ni
data[i] = val
end
Expand Down Expand Up @@ -998,7 +998,7 @@ Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, 1, Nv, 1)

nlevels(::VF{S, Nv}) where {S, Nv} = Nv

function Base.fill!(data::VF, val)
function Base.fill!(data::VF, val, ::Array)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[v] = val
Expand Down Expand Up @@ -1123,7 +1123,7 @@ function Base.length(data::VIJFH)
size(parent(data), 1) * size(parent(data), 5)
end

function Base.fill!(data::VIJFH, val)
function Base.fill!(data::VIJFH, val, ::Array)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1290,7 +1290,7 @@ end
function Base.length(data::VIFH)
nlevels(data) * size(parent(data), 4)
end
function Base.fill!(data::VIFH, val)
function Base.fill!(data::VIFH, val, ::Array)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1610,4 +1610,19 @@ array2data(
::VIJFH{<:Any, Nv, Nij},
) where {T, Nv, Nij} = VIJFH{T, Nv, Nij}(reshape(array, Nv, Nij, Nij, 1, :))

"""
backed_array(data::AbstractData)
Returns an `Array` or a `CUDA.CuArray` depending
on how `data` is backed.
"""
backed_array(dest::AbstractData) = _backed_array(dest)

_backed_array(x::Array) = x
_backed_array(x::SubArray) = _backed_array(parent(x))
_backed_array(x::Base.ReshapedArray) = _backed_array(parent(x))
_backed_array(x::AbstractData) = _backed_array(parent(x))

Base.fill!(dest::AbstractData, val) = Base.fill!(dest, val, backed_array(dest))

end # module
28 changes: 26 additions & 2 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,11 +481,18 @@ function Base.mapreduce(
end
end

Base.copyto!(
dest::AbstractData,
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) = Base.copyto!(dest, bc, backed_array(dest))

# broadcasting scalar assignment
# Performance optimization for the common identity scalar case: dest .= val

function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
::Array,
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
Expand All @@ -500,6 +507,7 @@ end
function Base.copyto!(
dest::DataF{S},
bc::BroadcastedUnionDataF{S, A},
::Array,
) where {S, A}
@inbounds dest[] = convert(S, bc[])
return dest
Expand All @@ -508,6 +516,7 @@ end
function Base.copyto!(
dest::IJFH{S, Nij},
bc::BroadcastedUnionIJFH{S, Nij},
::Array,
) where {S, Nij}
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
Expand All @@ -521,6 +530,7 @@ end
function Base.copyto!(
dest::IFH{S, Ni},
bc::BroadcastedUnionIFH{S, Ni},
::Array,
) where {S, Ni}
_, _, _, _, Nh = size(bc)
@inbounds for h in 1:Nh
Expand All @@ -535,6 +545,7 @@ end
function Base.copyto!(
dest::IJF{S, Nij},
bc::BroadcastedUnionIJF{S, Nij, A},
::Array,
) where {S, Nij, A}
@inbounds for j in 1:Nij, i in 1:Nij
idx = CartesianIndex(i, j, 1, 1, 1)
Expand All @@ -546,6 +557,7 @@ end
function Base.copyto!(
dest::IF{S, Ni},
bc::BroadcastedUnionIF{S, Ni, A},
::Array,
) where {S, Ni, A}
@inbounds for i in 1:Ni
idx = CartesianIndex(i, 1, 1, 1, 1)
Expand All @@ -558,6 +570,7 @@ end
function Base.copyto!(
dest::IF{S, Ni},
bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
::Array,
) where {S, Ni, A}
@inbounds for i in 1:Ni
idx = CartesianIndex(i, 1, 1, 1, 1)
Expand All @@ -570,6 +583,7 @@ end
function Base.copyto!(
dest::VF{S, Nv},
bc::BroadcastedUnionVF{S, Nv, A},
::Array,
) where {S, Nv, A}
@inbounds for v in 1:Nv
idx = CartesianIndex(1, 1, 1, v, 1)
Expand All @@ -581,6 +595,7 @@ end
function Base.copyto!(
dest::VIFH{S, Nv, Ni},
bc::BroadcastedUnionVIFH{S, Nv, Ni},
::Array,
) where {S, Nv, Ni}
(_, _, _, _, Nh) = size(bc)
# copy contiguous columns
Expand All @@ -595,6 +610,7 @@ end
function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::BroadcastedUnionVIJFH{S, Nv, Nij},
::Array,
) where {S, Nv, Nij}
# copy contiguous columns
_, _, _, _, Nh = size(dest)
Expand Down Expand Up @@ -635,12 +651,13 @@ function Base.copyto!(
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
fused_copyto!(fmb_inst, dest1)
fused_copyto!(fmb_inst, dest1, backed_array(dest1))
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S1, Nv1, Nij},
::Array,
) where {S1, Nv1, Nij}
_, _, _, _, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
Expand All @@ -657,6 +674,7 @@ end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij},
::Array,
) where {S, Nij}
# copy contiguous columns
_, _, _, Nv, Nh = size(dest1)
Expand All @@ -673,6 +691,7 @@ end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIFH{S, Nv1, Ni},
::Array,
) where {S, Nv1, Ni}
# copy contiguous columns
_, _, _, _, Nh = size(dest1)
Expand All @@ -689,6 +708,7 @@ end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S1, Nv1},
::Array,
) where {S1, Nv1}
for (dest, bc) in fmbc.pairs
@inbounds for v in 1:Nv1
Expand All @@ -700,7 +720,11 @@ function fused_copyto!(
return nothing
end

function fused_copyto!(fmbc::FusedMultiBroadcast, dest::DataF{S}) where {S}
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest::DataF{S},
::Array,
) where {S}
for (dest, bc) in fmbc.pairs
@inbounds dest[] = convert(S, bc[])
end
Expand Down
2 changes: 1 addition & 1 deletion src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function todata(bc::Base.Broadcast.Broadcasted{FieldStyle{DS}}) where {DS}
Base.Broadcast.Broadcasted{DS}(bc.f, _args)
end

# same logic as Base.Broadcasted (which only defines it for Tuples)
# same logic as Base.Broadcast.Broadcasted (which only defines it for Tuples)
Base.axes(bc::Base.Broadcast.Broadcasted{<:AbstractFieldStyle}) =
_axes(bc, bc.axes)
_axes(bc, ::Nothing) = Base.Broadcast.combine_axes(bc.args...)
Expand Down
Loading

0 comments on commit 9667382

Please sign in to comment.