Skip to content

Commit

Permalink
Fix dispatching for some cuda kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jun 6, 2024
1 parent e7ee513 commit 7cce12f
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 64 deletions.
2 changes: 2 additions & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import ClimaCore.Utilities: half
import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax

const cu_array = Union{CUDA.CuArray, SubArray{<:Any, <:Any, CUDA.CuArray}}

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
include(joinpath("cuda", "fields.jl"))
Expand Down
36 changes: 11 additions & 25 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@ import ClimaCore.DataLayouts: IJKFVH, IJFH, VIJFH, VIFH, IFH, IJF, IF, VF, DataF
import ClimaCore.DataLayouts: IJFHStyle, VIJFHStyle, VFStyle, DataFStyle
import ClimaCore.DataLayouts: promote_parent_array_type
import ClimaCore.DataLayouts: parent_array_type
import ClimaCore.DataLayouts: device_from_array_type, isascalar
import ClimaCore.DataLayouts: isascalar
import ClimaCore.DataLayouts: fused_copyto!
import Adapt
import CUDA

device_from_array_type(::Type{<:CUDA.CuArray}) = ClimaComms.CUDADevice()
device_from_array_type(::Type{<:SubArray{<:Any, <:Any, <:CUDA.CuArray}}) =
ClimaComms.CUDADevice()

parent_array_type(::Type{<:CUDA.CuArray{T, N, B} where {N}}) where {T, B} =
CUDA.CuArray{T, N, B} where {N}

Expand Down Expand Up @@ -61,7 +57,7 @@ end
function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij, A}, Base.Broadcast.Broadcasted{IJFHStyle{Nij, A}}},
) where {S, Nij, A <: CUDA.CuArray}
) where {S, Nij, A <: cu_array}
_, _, _, _, Nh = size(bc)
if Nh > 0
auto_launch!(
Expand All @@ -74,14 +70,7 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(
dest::IJFH{S, Nij, A},
val,
) where {
S,
Nij,
A <: Union{CUDA.CuArray, SubArray{<:Any, <:Any, <:CUDA.CuArray}},
}
function Base.fill!(dest::IJFH{S, Nij, <:cu_array}, val) where {S, Nij}
_, _, _, _, Nh = size(dest)
if Nh > 0
auto_launch!(
Expand All @@ -103,7 +92,7 @@ function Base.copyto!(
VIJFH{S, Nv, Nij, A},
Base.Broadcast.Broadcasted{VIJFHStyle{Nv, Nij, A}},
},
) where {S, Nv, Nij, A <: CUDA.CuArray}
) where {S, Nv, Nij, A <: cu_array}
_, _, _, _, Nh = size(bc)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand All @@ -118,10 +107,7 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(
dest::VIJFH{S, Nv, Nij, A},
val,
) where {S, Nv, Nij, A <: CUDA.CuArray}
function Base.fill!(dest::VIJFH{S, Nv, Nij, <:cu_array}, val) where {S, Nv, Nij}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Expand All @@ -141,7 +127,7 @@ end
function Base.copyto!(
dest::VF{S, Nv},
bc::Union{VF{S, Nv, A}, Base.Broadcast.Broadcasted{VFStyle{Nv, A}}},
) where {S, Nv, A <: CUDA.CuArray}
) where {S, Nv, A <: cu_array}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
Expand All @@ -154,7 +140,7 @@ function Base.copyto!(
end
return dest
end
function Base.fill!(dest::VF{S, Nv, A}, val) where {S, Nv, A <: CUDA.CuArray}
function Base.fill!(dest::VF{S, Nv, <:cu_array}, val) where {S, Nv}
_, _, _, _, Nh = size(dest)
if Nv > 0 && Nh > 0
auto_launch!(
Expand Down Expand Up @@ -236,7 +222,7 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S, Nv, Nij},
dest1::VIJFH{S, Nv, Nij, <:cu_array},
::ClimaComms.CUDADevice,
) where {S, Nv, Nij}
_, _, _, _, Nh = size(dest1)
Expand All @@ -256,7 +242,7 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij},
dest1::IJFH{S, Nij, <:cu_array},
::ClimaComms.CUDADevice,
) where {S, Nij}
_, _, _, _, Nh = size(dest1)
Expand All @@ -273,7 +259,7 @@ function fused_copyto!(
end
function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S, Nv},
dest1::VF{S, Nv, <:cu_array},
::ClimaComms.CUDADevice,
) where {S, Nv}
_, _, _, _, Nh = size(dest1)
Expand All @@ -291,7 +277,7 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::DataF{S},
dest1::DataF{S, <:cu_array},
::ClimaComms.CUDADevice,
) where {S}
auto_launch!(
Expand Down
6 changes: 0 additions & 6 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1475,12 +1475,6 @@ Adapt.adapt_structure(to, data::VF{S, Nv}) where {S, Nv} =
Adapt.adapt_structure(to, data::DataF{S}) where {S} =
DataF{S}(Adapt.adapt(to, parent(data)))

# TODO: Should the DataLayout be device-aware? So that we can
# determine if we're multi-threaded or not?
# This is only currently used in FusedMultiBroadcast kernels
device_from_array_type(::Type{<:AbstractArray}) = ClimaComms.CPUSingleThreaded()
ClimaComms.device(data::AbstractData) =
device_from_array_type(typeof(parent(data)))
empty_kernel_stats(::ClimaComms.AbstractDevice) = nothing
empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())

Expand Down
24 changes: 8 additions & 16 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -678,13 +678,12 @@ function Base.copyto!(
end,
)
# check_fused_broadcast_axes(fmbc) # we should already have checked the axes
fused_copyto!(fmb_inst, dest1, ClimaComms.device(dest1))
fused_copyto!(fmb_inst, dest1)
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIJFH{S1, Nv1, Nij},
::ClimaComms.AbstractCPUDevice,
) where {S1, Nv1, Nij}
_, _, _, _, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
Expand All @@ -700,9 +699,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::IJFH{S, Nij, A},
::ClimaComms.AbstractCPUDevice,
) where {S, Nij, A}
dest1::IJFH{S, Nij},
) where {S, Nij}
# copy contiguous columns
_, _, _, Nv, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
Expand All @@ -717,9 +715,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VIFH{S, Nv1, Ni, A},
::ClimaComms.AbstractCPUDevice,
) where {S, Nv1, Ni, A}
dest1::VIFH{S, Nv1, Ni},
) where {S, Nv1, Ni}
# copy contiguous columns
_, _, _, _, Nh = size(dest1)
for (dest, bc) in fmbc.pairs
Expand All @@ -734,9 +731,8 @@ end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest1::VF{S1, Nv1, A},
::ClimaComms.AbstractCPUDevice,
) where {S1, Nv1, A}
dest1::VF{S1, Nv1},
) where {S1, Nv1}
for (dest, bc) in fmbc.pairs
@inbounds for v in 1:Nv1
I = CartesianIndex(1, 1, 1, v, 1)
Expand All @@ -747,11 +743,7 @@ function fused_copyto!(
return nothing
end

function fused_copyto!(
fmbc::FusedMultiBroadcast,
dest::DataF{S},
::ClimaComms.AbstractCPUDevice,
) where {S}
function fused_copyto!(fmbc::FusedMultiBroadcast, dest::DataF{S}) where {S}
for (dest, bc) in fmbc.pairs
@inbounds dest[] = convert(S, bc[])
end
Expand Down
35 changes: 18 additions & 17 deletions test/Fields/field_multi_broadcast_fusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ function CenterExtrudedFiniteDifferenceSpaceLineHSpace(
return Spaces.ExtrudedFiniteDifferenceSpace(hspace, vspace)
end

function benchmark_kernel!(f!, X, Y)
function benchmark_kernel!(f!, X, Y, device)
println("\n--------------------------- $(nameof(typeof(f!))) ")
trial = benchmark_kernel!(f!, X, Y, ClimaComms.device(X.x1))
trial = benchmark_kernel!(f!, X, Y, device)
show(stdout, MIME("text/plain"), trial)
end
benchmark_kernel!(f!, X, Y, ::ClimaComms.CUDADevice) =
Expand Down Expand Up @@ -250,11 +250,12 @@ end

@testset "FusedMultiBroadcast VIJFH and VF" begin
FT = Float64
device = ClimaComms.device()
space = TU.CenterExtrudedFiniteDifferenceSpace(
FT;
zelem = 3,
helem = 4,
context = ClimaComms.context(),
context = ClimaComms.context(device),
)
X = Fields.FieldVector(
x1 = rand_field(FT, space),
Expand All @@ -269,11 +270,11 @@ end
test_kernel!(; fused!, unfused!, X, Y)
test_kernel!(; fused! = fused_bycolumn!, unfused! = unfused_bycolumn!, X, Y)

benchmark_kernel!(unfused!, X, Y)
benchmark_kernel!(fused!, X, Y)
benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)

benchmark_kernel!(unfused_bycolumn!, X, Y)
benchmark_kernel!(fused_bycolumn!, X, Y)
benchmark_kernel!(unfused_bycolumn!, X, Y, device)
benchmark_kernel!(fused_bycolumn!, X, Y, device)
nothing
end

Expand Down Expand Up @@ -306,11 +307,11 @@ end
Y,
)

benchmark_kernel!(unfused!, X, Y)
benchmark_kernel!(fused!, X, Y)
benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)

benchmark_kernel!(unfused_bycolumn!, X, Y)
benchmark_kernel!(fused_bycolumn!, X, Y)
benchmark_kernel!(unfused_bycolumn!, X, Y, device)
benchmark_kernel!(fused_bycolumn!, X, Y, device)
nothing
end
end
Expand All @@ -332,8 +333,8 @@ end
y3 = IJFH_data(),
)
test_kernel!(; fused!, unfused!, X, Y)
benchmark_kernel!(unfused!, X, Y)
benchmark_kernel!(fused!, X, Y)
benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)
nothing
end

Expand All @@ -350,8 +351,8 @@ end
X = Fields.FieldVector(; x1 = VF_data(), x2 = VF_data(), x3 = VF_data())
Y = Fields.FieldVector(; y1 = VF_data(), y2 = VF_data(), y3 = VF_data())
test_kernel!(; fused!, unfused!, X, Y)
benchmark_kernel!(unfused!, X, Y)
benchmark_kernel!(fused!, X, Y)
benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)
nothing
end

Expand All @@ -371,7 +372,7 @@ end
y3 = DataF_data(),
)
test_kernel!(; fused!, unfused!, X, Y)
benchmark_kernel!(unfused!, X, Y)
benchmark_kernel!(fused!, X, Y)
benchmark_kernel!(unfused!, X, Y, device)
benchmark_kernel!(fused!, X, Y, device)
nothing
end

0 comments on commit 7cce12f

Please sign in to comment.