Skip to content

Commit

Permalink
Add linear index support for pointwise kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 11, 2024
1 parent ee2b83e commit 9d5f538
Show file tree
Hide file tree
Showing 15 changed files with 750 additions and 151 deletions.
13 changes: 13 additions & 0 deletions ext/cuda/data_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,16 @@ function Adapt.adapt_structure(
end,
)
end

import Adapt
import CUDA
function Adapt.adapt_structure(
to::CUDA.KernelAdaptor,
bc::DataLayouts.NonExtrudedBroadcasted{Style},
) where {Style}
DataLayouts.NonExtrudedBroadcasted{Style}(
adapt_f(to, bc.f),
Adapt.adapt(to, bc.args),
Adapt.adapt(to, bc.axes),
)
end
113 changes: 23 additions & 90 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
@@ -1,90 +1,9 @@
import ClimaCore.DataLayouts:
to_non_extruded_broadcasted, has_uniform_datalayouts
DataLayouts._device_dispatch(x::CUDA.CuArray) = ToCUDA()

function knl_copyto!(dest, src)

i = CUDA.threadIdx().x
j = CUDA.threadIdx().y

h = CUDA.blockIdx().x
v = CUDA.blockDim().z * (CUDA.blockIdx().y - 1) + CUDA.threadIdx().z

if v <= size(dest, 4)
I = CartesianIndex((i, j, 1, v, h))
@inbounds dest[I] = src[I]
end
return nothing
end

function Base.copyto!(
dest::IJFH{S, Nij, Nh},
bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh},
::ToCUDA,
) where {S, Nij, Nh}
if Nh > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (Nij, Nij),
blocks_s = (Nh, 1),
)
end
return dest
end

function Base.copyto!(
dest::VIJFH{S, Nv, Nij, Nh},
bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh},
::ToCUDA,
) where {S, Nv, Nij, Nh}
if Nv > 0 && Nh > 0
Nv_per_block = min(Nv, fld(256, Nij * Nij))
Nv_blocks = cld(Nv, Nv_per_block)
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (Nij, Nij, Nv_per_block),
blocks_s = (Nh, Nv_blocks),
)
end
return dest
end

function Base.copyto!(
dest::VF{S, Nv},
bc::DataLayouts.BroadcastedUnionVF{S, Nv},
::ToCUDA,
) where {S, Nv}
if Nv > 0
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (1, Nv),
)
end
return dest
end

function Base.copyto!(
dest::DataF{S},
bc::DataLayouts.BroadcastedUnionDataF{S},
::ToCUDA,
) where {S}
auto_launch!(
knl_copyto!,
(dest, bc),
dest;
threads_s = (1, 1),
blocks_s = (1, 1),
)
return dest
end

import ClimaCore.DataLayouts: isascalar
function knl_copyto_flat!(dest::AbstractData, bc, us)
function knl_copyto_cart!(dest::AbstractData, bc, us)
@inbounds begin
tidx = thread_index()
if tidx get_N(us)
Expand All @@ -96,24 +15,38 @@ function knl_copyto_flat!(dest::AbstractData, bc, us)
return nothing
end

function knl_copyto_linear!(dest::AbstractData, bc, us)
@inbounds begin
tidx = thread_index()
if tidx get_N(us)
dest[tidx] = bc[tidx]
end
end
return nothing
end

function cuda_copyto!(dest::AbstractData, bc)
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
(Nv > 0 && Nh > 0) || return dest
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
if has_uniform_datalayouts(bc)
bc′ = to_non_extruded_broadcasted(bc)
auto_launch!(knl_copyto_linear!, (dest, bc′, us), dest; auto = true)
else
auto_launch!(knl_copyto_cart!, (dest, bc, us), dest; auto = true)
end
return dest
end

# TODO: can we use CUDA's luanch configuration for all data layouts?
# Currently, it seems to have a slight performance degradation.
#! format: off
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IJF{S, Nij}, bc::DataLayouts.BroadcastedUnionIJF{S, Nij}, ::ToCUDA) where {S, Nij} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IF{S, Ni}, bc::DataLayouts.BroadcastedUnionIF{S, Ni}, ::ToCUDA) where {S, Ni} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VIFH{S, Nv, Ni, Nh}, bc::DataLayouts.BroadcastedUnionVIFH{S, Nv, Ni, Nh}, ::ToCUDA) where {S, Nv, Ni, Nh} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
# Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VIJFH{S, Nv, Nij, Nh}, bc::DataLayouts.BroadcastedUnionVIJFH{S, Nv, Nij, Nh}, ::ToCUDA) where {S, Nv, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::VF{S, Nv}, bc::DataLayouts.BroadcastedUnionVF{S, Nv}, ::ToCUDA) where {S, Nv} = cuda_copyto!(dest, bc)
Base.copyto!(dest::DataF{S}, bc::DataLayouts.BroadcastedUnionDataF{S}, ::ToCUDA) where {S} = cuda_copyto!(dest, bc)
#! format: on
4 changes: 1 addition & 3 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ function knl_fill_flat!(dest::AbstractData, val, us)
@inbounds begin
tidx = thread_index()
if tidx get_N(us)
n = size(dest)
I = kernel_indexes(tidx, n)
@inbounds dest[I] = val
@inbounds dest[tidx] = val
end
end
return nothing
Expand Down
33 changes: 33 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,37 @@ get_Nij(::IF{S, Nij}) where {S, Nij} = Nij
@inline field_dim(::VIJFH) = 4
@inline field_dim(::VIFH) = 3

# Returns the size of the backing array.
@inline array_size(::IJKFVH{S, Nij, Nk, Nv, Nh}) where {S, Nij, Nk, Nv, Nh} =
(Nij, Nij, Nk, 1, Nv, Nh)
@inline array_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} = (Nij, Nij, 1, Nh)
@inline array_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, Nh)
@inline array_size(::DataF{S}) where {S} = (1,)
@inline array_size(::IJF{S, Nij}) where {S, Nij} = (Nij, Nij, 1)
@inline array_size(::IF{S, Ni}) where {S, Ni} = (Ni, 1)
@inline array_size(::VF{S, Nv}) where {S, Nv} = (Nv, 1)
@inline array_size(::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
(Nv, Nij, Nij, 1, Nh)
@inline array_size(::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
(Nv, Ni, 1, Nh)

@inline farray_size(
data::IJKFVH{S, Nij, Nk, Nv, Nh},
) where {S, Nij, Nk, Nv, Nh} = (Nij, Nij, Nk, ncomponents(data), Nv, Nh)
@inline farray_size(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
(Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} =
(Ni, ncomponents(data), Nh)
@inline farray_size(data::DataF{S}) where {S} = (ncomponents(data),)
@inline farray_size(data::IJF{S, Nij}) where {S, Nij} =
(Nij, Nij, ncomponents(data))
@inline farray_size(data::IF{S, Ni}) where {S, Ni} = (Ni, ncomponents(data))
@inline farray_size(data::VF{S, Nv}) where {S, Nv} = (Nv, ncomponents(data))
@inline farray_size(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
(Nv, Nij, Nij, ncomponents(data), Nh)
@inline farray_size(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
(Nv, Ni, ncomponents(data), Nh)

#! format: off
@inline to_data_specific(::IJFH, I::CartesianIndex) = CartesianIndex(I.I[1], I.I[2], 1, I.I[5])
@inline to_data_specific(::IFH, I::CartesianIndex) = CartesianIndex(I.I[1], 1, I.I[5])
Expand Down Expand Up @@ -1600,10 +1631,12 @@ _device_dispatch(x::AbstractData) = _device_dispatch(parent(x))
_device_dispatch(x::SArray) = ToCPU()
_device_dispatch(x::MArray) = ToCPU()

include("non_extruded_broadcasted.jl")
include("copyto.jl")
include("fused_copyto.jl")
include("fill.jl")
include("mapreduce.jl")
include("has_uniform_datalayouts.jl")

slab_index(i, j) = CartesianIndex(i, j, 1, 1, 1)
slab_index(i) = CartesianIndex(i, 1, 1, 1, 1)
Expand Down
1 change: 1 addition & 0 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ DataSlab2DStyle(::Type{VIJFHStyle{Nv, Nij, Nh, A}}) where {Nv, Nij, Nh, A} =
#####

#! format: off
const BroadcastedUnionData = Union{Base.Broadcast.Broadcasted{<:DataStyle}, AbstractData}
const BroadcastedUnionIJFH{S, Nij, Nh, A} = Union{Base.Broadcast.Broadcasted{IJFHStyle{Nij, Nh, A}}, IJFH{S, Nij, Nh, A}}
const BroadcastedUnionIFH{S, Ni, Nh, A} = Union{Base.Broadcast.Broadcasted{IFHStyle{Ni, Nh, A}}, IFH{S, Ni, Nh, A}}
const BroadcastedUnionIJF{S, Nij, A} = Union{Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}, IJF{S, Nij, A}}
Expand Down
18 changes: 15 additions & 3 deletions src/DataLayouts/copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@
##### Dispatching and edge cases
#####

Base.copyto!(
dest::AbstractData,
function Base.copyto!(
dest::AbstractData{S},
bc::Union{AbstractData, Base.Broadcast.Broadcasted},
) = Base.copyto!(dest, bc, device_dispatch(dest))
) where {S}
dev = device_dispatch(dest)
if dev isa ToCPU && has_uniform_datalayouts(bc)
# Specialize on linear indexing case:
bc′ = Base.Broadcast.instantiate(to_non_extruded_broadcasted(bc))
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = convert(S, bc′[I])
end
else
Base.copyto!(dest, bc, device_dispatch(dest))
end
return dest
end

# Specialize on non-Broadcasted objects
function Base.copyto!(dest::D, src::D) where {D <: AbstractData}
Expand Down
61 changes: 7 additions & 54 deletions src/DataLayouts/fill.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,13 @@
function Base.fill!(data::IJFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
function Base.fill!(dest::AbstractData, val, ::ToCPU)
@inbounds @simd for I in 1:get_N(UniversalSize(dest))
dest[I] = val
end
return data
return dest
end

function Base.fill!(data::IFH, val, ::ToCPU)
(_, _, _, _, Nh) = size(data)
@inbounds for h in 1:Nh
fill!(slab(data, h), val)
end
return data
end

function Base.fill!(data::DataF, val, ::ToCPU)
@inbounds data[] = val
return data
end

function Base.fill!(data::IJF{S, Nij}, val, ::ToCPU) where {S, Nij}
@inbounds for j in 1:Nij, i in 1:Nij
data[CartesianIndex(i, j, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::IF{S, Ni}, val, ::ToCPU) where {S, Ni}
@inbounds for i in 1:Ni
data[CartesianIndex(i, 1, 1, 1, 1)] = val
end
return data
end

function Base.fill!(data::VF, val, ::ToCPU)
Nv = nlevels(data)
@inbounds for v in 1:Nv
data[CartesianIndex(1, 1, 1, v, 1)] = val
end
return data
end

function Base.fill!(data::VIJFH, val, ::ToCPU)
(Ni, Nj, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
end

function Base.fill!(data::VIFH, val, ::ToCPU)
(Ni, _, _, Nv, Nh) = size(data)
@inbounds for h in 1:Nh, v in 1:Nv
fill!(slab(data, v, h), val)
end
return data
function Base.fill!(dest::DataF, val, ::ToCPU)
@inbounds dest[] = val
return dest
end

Base.fill!(dest::AbstractData, val) =
Expand Down
60 changes: 60 additions & 0 deletions src/DataLayouts/has_uniform_datalayouts.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
@inline function first_datalayout_in_bc(args::Tuple, rargs...)
x1 = first_datalayout_in_bc(args[1], rargs...)
x1 isa AbstractData && return x1
return first_datalayout_in_bc(Base.tail(args), rargs...)
end

@inline first_datalayout_in_bc(args::Tuple{Any}, rargs...) =
first_datalayout_in_bc(args[1], rargs...)
@inline first_datalayout_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_datalayout_in_bc(x) = nothing
@inline first_datalayout_in_bc(x::AbstractData) = x

@inline first_datalayout_in_bc(bc::Base.Broadcast.Broadcasted) =
first_datalayout_in_bc(bc.args)

@inline _has_uniform_datalayouts_args(truesofar, start, args::Tuple, rargs...) =
truesofar &&
_has_uniform_datalayouts(truesofar, start, args[1], rargs...) &&
_has_uniform_datalayouts_args(truesofar, start, Base.tail(args), rargs...)

@inline _has_uniform_datalayouts_args(
truesofar,
start,
args::Tuple{Any},
rargs...,
) = truesofar && _has_uniform_datalayouts(truesofar, start, args[1], rargs...)
@inline _has_uniform_datalayouts_args(truesofar, _, args::Tuple{}, rargs...) =
truesofar

@inline function _has_uniform_datalayouts(
truesofar,
start,
bc::Base.Broadcast.Broadcasted,
)
return truesofar && _has_uniform_datalayouts_args(truesofar, start, bc.args)
end
for DL in (:IJKFVH, :IJFH, :IFH, :DataF, :IJF, :IF, :VF, :VIJFH, :VIFH)
@eval begin
@inline _has_uniform_datalayouts(truesofar, ::$(DL), ::$(DL)) = true
end
end
@inline _has_uniform_datalayouts(truesofar, _, x::AbstractData) = false
@inline _has_uniform_datalayouts(truesofar, _, x) = truesofar

"""
has_uniform_datalayouts
Find the first datalayout in the broadcast expression (BCE),
and compares against every other datalayout in the BCE. Returns
- `true` if the broadcasted object has only a single kind of datalayout (e.g. VF,VF, VIJFH,VIJFH)
- `false` if the broadcasted object has multiple kinds of datalayouts (e.g. VIJFH, VIFH)
Note: a broadcasted object can have different _types_,
e.g., `VIFJH{Float64}` and `VIFJH{Tuple{Float64,Float64}}`
but not different kinds, e.g., `VIFJH{Float64}` and `VF{Float64}`.
"""
function has_uniform_datalayouts end

@inline has_uniform_datalayouts(bc::Base.Broadcast.Broadcasted) =
_has_uniform_datalayouts_args(true, first_datalayout_in_bc(bc), bc.args)

@inline has_uniform_datalayouts(bc::AbstractData) = true
Loading

0 comments on commit 9d5f538

Please sign in to comment.