Skip to content

Commit

Permalink
CPU-GPU specialization -> dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 1, 2024
1 parent 4d9e9cd commit 8b1f400
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 57 deletions.
3 changes: 0 additions & 3 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax

const CuArrayBackedTypes =
Union{CUDA.CuArray, SubArray{<:Any, <:Any, <:CUDA.CuArray}}

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
include(joinpath("cuda", "fields.jl"))
Expand Down
12 changes: 8 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes},
dest1::VIJFH{S, Nv, Nij},
::CUDA.CuArray,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -92,7 +93,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij, <:CuArrayBackedTypes},
dest1::IJFH{S, Nij},
::CUDA.CuArray,
) where {S, Nij}
_, _, _, _, Nh = size(dest1)
if Nh > 0
Expand All @@ -108,7 +110,8 @@ function fused_copyto!(
end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S, Nv, <:CuArrayBackedTypes},
dest1::VF{S, Nv},
::CUDA.CuArray,
) where {S, Nv}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -125,7 +128,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::DataF{S, <:CuArrayBackedTypes},
dest1::DataF{S},
::CUDA.CuArray,
) where {S}
auto_launch!(
knl_fused_copyto!,
Expand Down
50 changes: 19 additions & 31 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
DataLayouts._backed_array(x::CUDA.CuArray) = x

function knl_copyto!(dest, src)

i = CUDA.threadIdx().x
Expand All @@ -15,8 +17,9 @@ end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, A},
) where {S, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij},
::CUDA.CuArray,
) where {S, Nij}
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand All @@ -32,8 +35,9 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, A},
) where {S, Nv, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij},
::CUDA.CuArray,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand All @@ -49,27 +53,11 @@ function Base.copyto!(
return dest
end

function Base.copyto!(
dest::VF{S, Nv},
bc::DataLayouts.BroadcastedUnionVF{S, Nv, A},
) where {S, Nv, A <: CuArrayBackedTypes}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (Nh, Nv),
)
end
return dest
end

function Base.copyto!(
dest::DataF{S},
bc::DataLayouts.BroadcastedUnionDataF{S, A},
) where {S, A <: CUDA.CuArray}
bc::DataLayouts.BroadcastedUnionDataF{S},
::CUDA.CuArray,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
Expand Down Expand Up @@ -104,12 +92,12 @@ end
# TODO: can we use CUDA's luanch configuration for all data layouts?
# Currently, it seems to have a slight performance degredation.
#! format: off
# Base.copyto!(dest::IJFH{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIF{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, <:CuArrayBackedTypes}) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, <:CuArrayBackedTypes}) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VF{S, Nv, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVF{S, Nv, <:CuArrayBackedTypes}) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionDataF{S, <:CuArrayBackedTypes}) where {S} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::CUDA.CuArray) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::CUDA.CuArray) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::CUDA.CuArray) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::CUDA.CuArray) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::CUDA.CuArray) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::CUDA.CuArray) where {S} = cuda_copyto!(dest, bc)
#! format: on
16 changes: 8 additions & 8 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ function cuda_fill!(dest::AbstractData, val)
end

#! format: off
Base.fill!(dest::IJFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any}, val, ::CUDA.CuArray) = cuda_fill!(dest, val)
#! format: on
49 changes: 41 additions & 8 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function Base.size(data::IJFH{S, Nij}) where {S, Nij}
(Nij, Nij, 1, Nv, Nh)
end

function Base.fill!(data::IJFH, val)
function Base.fill!(data::IJFH, val, ::Array)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -494,7 +494,7 @@ function Base.size(data::IFH{S, Ni}) where {S, Ni}
(Ni, 1, 1, Nv, Nh)
end

function Base.fill!(data::IFH, val)
function Base.fill!(data::IFH, val, ::Array)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -618,7 +618,7 @@ function DataF(x::T) where {T}
end


function Base.fill!(data::DataF, val)
function Base.fill!(data::DataF, val, ::Array)
@inbounds data[] = val
return data
end
Expand Down Expand Up @@ -746,7 +746,7 @@ end
function Base.size(data::IJF{S, Nij}) where {S, Nij}
return (Nij, Nij, 1, 1, 1)
end
function Base.fill!(data::IJF{S, Nij}, val) where {S, Nij}
function Base.fill!(data::IJF{S, Nij}, val, ::Array) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[i, j] = val
end
Expand Down Expand Up @@ -884,7 +884,7 @@ function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T}
return IF{S′, Ni}(similar(array, T))
end

function Base.fill!(data::IF{S, Ni}, val) where {S, Ni}
function Base.fill!(data::IF{S, Ni}, val, ::Array) where {S, Ni}
@inbounds for i in 1:Ni
data[i] = val
end
Expand Down Expand Up @@ -998,7 +998,7 @@ Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, 1, Nv, 1)

nlevels(::VF{S, Nv}) where {S, Nv} = Nv

function Base.fill!(data::VF, val)
function Base.fill!(data::VF, val, ::Array)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[v] = val
Expand Down Expand Up @@ -1123,7 +1123,7 @@ function Base.length(data::VIJFH)
size(parent(data), 1) * size(parent(data), 5)
end

function Base.fill!(data::VIJFH, val)
function Base.fill!(data::VIJFH, val, ::Array)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1290,7 +1290,7 @@ end
function Base.length(data::VIFH)
nlevels(data) * size(parent(data), 4)
end
function Base.fill!(data::VIFH, val)
function Base.fill!(data::VIFH, val, ::Array)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1610,4 +1610,37 @@ array2data(
::VIJFH{<:Any, Nv, Nij},
) where {T, Nv, Nij} = VIJFH{T, Nv, Nij}(reshape(array, Nv, Nij, Nij, 1, :))

"""
backed_array(dest, bc::Union{AbstractData, Base.Broadcast.Broadcasted})
backed_array(dest, val::Real)
Returns an `Array` or a `CUDA.CuArray`
"""
function backed_array end

_backed_array(x::Array) = x
_backed_array(x::SubArray) = _backed_array(parent(x))
_backed_array(x::Base.ReshapedArray) = _backed_array(parent(x))
_backed_array(x::AbstractData) = _backed_array(parent(x))

function backed_array(
dest::AbstractData,
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
)
backed_dest = _backed_array(dest)
if isascalar(bc)
return backed_dest
else
if typeof(backed_dest) == typeof(_backed_array(bc))
return backed_dest
else
error("Mismatched array type in backed_array")
end
end
end

backed_array(dest::AbstractData) = _backed_array(dest)

Base.fill!(dest::AbstractData, val) = Base.fill!(dest, val, backed_array(dest))

end # module
Loading

0 comments on commit 8b1f400

Please sign in to comment.