Skip to content

Commit

Permalink
CPU-GPU specialization -> dispatching
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 1, 2024
1 parent 4d9e9cd commit 988ca7f
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 57 deletions.
4 changes: 1 addition & 3 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@ using CUDA
using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DataLayouts: mapreduce_cuda
import ClimaCore.DataLayouts: ToCUDA
import ClimaCore.DataLayouts: slab, column
import ClimaCore.Utilities: half
import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax

const CuArrayBackedTypes =
Union{CUDA.CuArray, SubArray{<:Any, <:Any, <:CUDA.CuArray}}

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
include(joinpath("cuda", "fields.jl"))
Expand Down
12 changes: 8 additions & 4 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes},
dest1::VIJFH{S, Nv, Nij},
::ToCUDA,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -92,7 +93,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij, <:CuArrayBackedTypes},
dest1::IJFH{S, Nij},
::ToCUDA,
) where {S, Nij}
_, _, _, _, Nh = size(dest1)
if Nh > 0
Expand All @@ -108,7 +110,8 @@ function fused_copyto!(
end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S, Nv, <:CuArrayBackedTypes},
dest1::VF{S, Nv},
::ToCUDA,
) where {S, Nv}
_, _, _, _, Nh = size(dest1)
if Nv > 0 && Nh > 0
Expand All @@ -125,7 +128,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::DataF{S, <:CuArrayBackedTypes},
dest1::DataF{S},
::ToCUDA,
) where {S}
auto_launch!(
knl_fused_copyto!,
Expand Down
50 changes: 19 additions & 31 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
DataLayouts._device_dispatch(x::CUDA.CuArray) = DataLayouts.ToCUDA()

function knl_copyto!(dest, src)

i = CUDA.threadIdx().x
Expand All @@ -15,8 +17,9 @@ end

function Base.copyto!(
dest::IJFH{S, Nij},
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, A},
) where {S, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij},
::ToCUDA,
) where {S, Nij}
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand All @@ -32,8 +35,9 @@ end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij},
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, A},
) where {S, Nv, Nij, A <: CuArrayBackedTypes}
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij},
::ToCUDA,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand All @@ -49,27 +53,11 @@ function Base.copyto!(
return dest
end

function Base.copyto!(
dest::VF{S, Nv},
bc::DataLayouts.BroadcastedUnionVF{S, Nv, A},
) where {S, Nv, A <: CuArrayBackedTypes}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (Nh, Nv),
)
end
return dest
end

function Base.copyto!(
dest::DataF{S},
bc::DataLayouts.BroadcastedUnionDataF{S, A},
) where {S, A <: CUDA.CuArray}
bc::DataLayouts.BroadcastedUnionDataF{S},
::ToCUDA,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
Expand Down Expand Up @@ -104,12 +92,12 @@ end
# TODO: can we use CUDA's luanch configuration for all data layouts?
# Currently, it seems to have a slight performance degredation.
#! format: off
# Base.copyto!(dest::IJFH{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij, <:CuArrayBackedTypes}) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionIF{S, Ni, <:CuArrayBackedTypes}) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, <:CuArrayBackedTypes}) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, <:CuArrayBackedTypes}) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VF{S, Nv, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionVF{S, Nv, <:CuArrayBackedTypes}) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S, <:CuArrayBackedTypes}, bc::DataLayouts.BroadcastedUnionDataF{S, <:CuArrayBackedTypes}) where {S} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIFH{S, Nv, Ni}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni}, ::ToCUDA) where {S, Nv, Ni} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij}, ::ToCUDA) where {S, Nv, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
#! format: on
16 changes: 8 additions & 8 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ function cuda_fill!(dest::AbstractData, val)
end

#! format: off
Base.fill!(dest::IJFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any, <:CuArrayBackedTypes}, val) = cuda_fill!(dest, val)
Base.fill!(dest::IJFH{<:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::IFH{<:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::IJF{<:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::IF{<:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::VIFH{<:Any, <:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::VIJFH{<:Any, <:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::VF{<:Any, <:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
Base.fill!(dest::DataF{<:Any}, val, ::ToCUDA) = cuda_fill!(dest, val)
#! format: on
38 changes: 30 additions & 8 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ import Adapt
import ..slab, ..slab_args, ..column, ..column_args, ..level
export slab, column, level, IJFH, IJF, IFH, IF, VF, VIJFH, VIFH, DataF

abstract type AbstractDispatchToDevice end
struct ToCPU <: AbstractDispatchToDevice end
struct ToCUDA <: AbstractDispatchToDevice end

include("struct.jl")

abstract type AbstractData{S} end
Expand Down Expand Up @@ -325,7 +329,7 @@ function Base.size(data::IJFH{S, Nij}) where {S, Nij}
(Nij, Nij, 1, Nv, Nh)
end

function Base.fill!(data::IJFH, val)
function Base.fill!(data::IJFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -494,7 +498,7 @@ function Base.size(data::IFH{S, Ni}) where {S, Ni}
(Ni, 1, 1, Nv, Nh)
end

function Base.fill!(data::IFH, val)
function Base.fill!(data::IFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
Expand Down Expand Up @@ -618,7 +622,7 @@ function DataF(x::T) where {T}
end


function Base.fill!(data::DataF, val)
function Base.fill!(data::DataF, val, ::ToCPU)
@inbounds data[] = val
return data
end
Expand Down Expand Up @@ -746,7 +750,7 @@ end
function Base.size(data::IJF{S, Nij}) where {S, Nij}
return (Nij, Nij, 1, 1, 1)
end
function Base.fill!(data::IJF{S, Nij}, val) where {S, Nij}
function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[i, j] = val
end
Expand Down Expand Up @@ -884,7 +888,7 @@ function replace_basetype(data::IF{S, Ni}, ::Type{T}) where {S, Ni, T}
return IF{S′, Ni}(similar(array, T))
end

function Base.fill!(data::IF{S, Ni}, val) where {S, Ni}
function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
@inbounds for i in 1:Ni
data[i] = val
end
Expand Down Expand Up @@ -998,7 +1002,7 @@ Base.size(data::VF{S, Nv}) where {S, Nv} = (1, 1, 1, Nv, 1)

nlevels(::VF{S, Nv}) where {S, Nv} = Nv

function Base.fill!(data::VF, val)
function Base.fill!(data::VF, val, ::ToCPU)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[v] = val
Expand Down Expand Up @@ -1123,7 +1127,7 @@ function Base.length(data::VIJFH)
size(parent(data), 1) * size(parent(data), 5)
end

function Base.fill!(data::VIJFH, val)
function Base.fill!(data::VIJFH, val, ::ToCPU)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1290,7 +1294,7 @@ end
function Base.length(data::VIFH)
nlevels(data) * size(parent(data), 4)
end
function Base.fill!(data::VIFH, val)
function Base.fill!(data::VIFH, val, ::ToCPU)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
Expand Down Expand Up @@ -1610,4 +1614,22 @@ array2data(
::VIJFH{<:Any, Nv, Nij},
) where {T, Nv, Nij} = VIJFH{T, Nv, Nij}(reshape(array, Nv, Nij, Nij, 1, :))

"""
device_dispatch(data::AbstractData)
Returns an `Array` or a `CUDA.CuArray` depending
on how `data` is backed.
"""
device_dispatch(dest::AbstractData) = _device_dispatch(dest)

_device_dispatch(x::Array) = ToCPU()
_device_dispatch(x::SubArray) = _device_dispatch(parent(x))
_device_dispatch(x::Base.ReshapedArray) = _device_dispatch(parent(x))
_device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
_device_dispatch(x::SArray) = ToCPU()
_device_dispatch(x::MArray) = ToCPU()

Base.fill!(dest::AbstractData, val) =
Base.fill!(dest, val, device_dispatch(dest))

end # module
Loading

0 comments on commit 988ca7f

Please sign in to comment.