Skip to content

Commit

Permalink
Merge pull request #2041 from CliMA/ck/latency3
Browse files Browse the repository at this point in the history
Add singleton objects to reduce specialization
  • Loading branch information
charleskawczynski authored Oct 16, 2024
2 parents 2ab2a0b + 1c142a3 commit 90572d1
Showing 1 changed file with 120 additions and 76 deletions.
196 changes: 120 additions & 76 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ Returns `Nh`.
"""
@inline get_Nh(us::UniversalSize{Ni, Nj, Nv}) where {Ni, Nj, Nv} = us.Nh

@inline get_Nh_dynamic(data::AbstractData) = size(parent(data), h_dim(data))
# TODO: inline so we don't overspecialize on these helpers
# `get_Nh_dynamic` reads the extent of the parent array along the layout's
# `h_dim` at run time; the layout is identified through its singleton tag.
@inline get_Nh_dynamic(data::AbstractData) =
size(parent(data), h_dim(singleton(data)))
# The accessors below instead build `UniversalSize(data)` and forward to the
# corresponding `UniversalSize` method.
@inline get_Nh(data::AbstractData) = get_Nh(UniversalSize(data))
@inline get_Nij(data::AbstractData) = get_Nij(UniversalSize(data))
@inline get_Nv(data::AbstractData) = get_Nv(UniversalSize(data))
Expand Down Expand Up @@ -269,12 +271,12 @@ Base.@propagate_inbounds function Base.getproperty(
SS = fieldtype(S, i)
offset = fieldtypeoffset(T, S, i)
nbytes = typesize(T, SS)
fdim = field_dim(data)
fdim = field_dim(singleton(data))
Ipre = ntuple(i -> Colon(), Val(fdim - 1))
Ipost = ntuple(i -> Colon(), Val(ndims(data) - fdim))
dataview =
@inbounds view(array, Ipre..., (offset + 1):(offset + nbytes), Ipost...)
union_all(data){SS, Base.tail(type_params(data))...}(dataview)
union_all(singleton(data)){SS, Base.tail(type_params(data))...}(dataview)
end

@noinline _property_view(
Expand All @@ -294,19 +296,19 @@ Base.@propagate_inbounds @generated function _property_view(
T = eltype(parent_array_type(AD))
offset = fieldtypeoffset(T, S, Val(Idx))
nbytes = typesize(T, SS)
fdim = field_dim(AD)
fdim = field_dim(AD.name.wrapper)
Ipre = ntuple(i -> Colon(), Val(fdim - 1))
Ipost = ntuple(i -> Colon(), Val(ndims(data) - fdim))
field_byterange = (offset + 1):(offset + nbytes)
return :($(union_all(AD)){$SS, $(Base.tail(type_params(AD)))...}(
return :($(AD.name.wrapper){$SS, $(Base.tail(type_params(AD)))...}(
@inbounds view(parent(data), $Ipre..., $field_byterange, $Ipost...)
))
end

function replace_basetype(data::AbstractData{S}, ::Type{T}) where {S, T}
array = parent(data)
S′ = replace_basetype(eltype(array), T, S)
return union_all(data){S′, Base.tail(type_params(data))...}(
return union_all(singleton(data)){S′, Base.tail(type_params(data))...}(
similar(array, T),
)
end
Expand All @@ -318,6 +320,23 @@ function maybe_populate!(array, ::typeof(rand))
parent(array) .= typeof(array)(rand(eltype(array), size(array)))
end

# ================== Singletons

# Parameter-free tag types, one per datalayout. `singleton(data)` maps a
# datalayout to its tag, so helpers such as `field_dim`, `h_dim`,
# `union_all`, `to_data_specific` and `to_universal_index` can dispatch on
# the tag instead of on the fully parameterized datalayout type. We use this
# to help reduce overspecialization.
abstract type AbstractDataSingleton end
struct IJKFVHSingleton <: AbstractDataSingleton end
struct IJFHSingleton <: AbstractDataSingleton end
struct IFHSingleton <: AbstractDataSingleton end
struct DataFSingleton <: AbstractDataSingleton end
struct IJFSingleton <: AbstractDataSingleton end
struct IFSingleton <: AbstractDataSingleton end
struct VFSingleton <: AbstractDataSingleton end
struct VIJFHSingleton <: AbstractDataSingleton end
struct VIFHSingleton <: AbstractDataSingleton end
struct IH1JH2Singleton <: AbstractDataSingleton end
struct IV1JH2Singleton <: AbstractDataSingleton end

# ==================
# Data3D DataLayout
# ==================
Expand Down Expand Up @@ -612,7 +631,7 @@ Base.@propagate_inbounds function Base.getindex(data::DataF{S}) where {S}
@inbounds get_struct(
parent(data),
S,
Val(field_dim(data)),
Val(field_dim(singleton(data))),
CartesianIndex(1),
)
end
Expand All @@ -625,7 +644,7 @@ Base.@propagate_inbounds function Base.setindex!(data::DataF{S}, val) where {S}
@inbounds set_struct!(
parent(data),
convert(S, val),
Val(field_dim(data)),
Val(field_dim(singleton(data))),
CartesianIndex(1),
)
end
Expand Down Expand Up @@ -1205,16 +1224,18 @@ rebuild(data::AbstractData, ::Type{DA}) where {DA} =
rebuild(data, DA(getfield(data, :array)))

Base.copy(data::AbstractData) =
union_all(data){type_params(data)...}(copy(parent(data)))
union_all(singleton(data)){type_params(data)...}(copy(parent(data)))

# broadcast machinery
include("broadcast.jl")

Adapt.adapt_structure(to, data::AbstractData{S}) where {S} =
union_all(data){type_params(data)...}(Adapt.adapt(to, parent(data)))
union_all(singleton(data)){type_params(data)...}(
Adapt.adapt(to, parent(data)),
)

rebuild(data::AbstractData, array::AbstractArray) =
union_all(data){type_params(data)...}(array)
union_all(singleton(data)){type_params(data)...}(array)

empty_kernel_stats(::ClimaComms.AbstractDevice) = nothing
empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
Expand All @@ -1233,8 +1254,7 @@ empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
@inline get_Nij(::IF{S, Nij}) where {S, Nij} = Nij

"""
field_dim(data::AbstractData)
field_dim(::Type{<:AbstractData})
field_dim(::AbstractDataSingleton)
This is an internal function, please do not use outside of ClimaCore.
Expand All @@ -1244,20 +1264,28 @@ This function is helpful for writing generic
code, when reconstructing new datalayouts with new
type parameters.
"""
@inline field_dim(data::AbstractData) = field_dim(typeof(data))
@inline field_dim(::Type{<:IJKFVH}) = 4
@inline field_dim(::Type{<:IJFH}) = 3
@inline field_dim(::Type{<:IFH}) = 2
@inline field_dim(::Type{<:DataF}) = 1
@inline field_dim(::Type{<:IJF}) = 3
@inline field_dim(::Type{<:IF}) = 2
@inline field_dim(::Type{<:VF}) = 2
@inline field_dim(::Type{<:VIJFH}) = 4
@inline field_dim(::Type{<:VIFH}) = 3

"""
h_dim(data::AbstractData)
h_dim(::Type{<:AbstractData})
@inline field_dim(::IJKFVHSingleton) = 4
@inline field_dim(::IJFHSingleton) = 3
@inline field_dim(::IFHSingleton) = 2
@inline field_dim(::DataFSingleton) = 1
@inline field_dim(::IJFSingleton) = 3
@inline field_dim(::IFSingleton) = 2
@inline field_dim(::VFSingleton) = 2
@inline field_dim(::VIJFHSingleton) = 4
@inline field_dim(::VIFHSingleton) = 3

# The singleton methods above are the form used on instances, as in
# `field_dim(singleton(data))`. The methods below return the same values but
# are keyed by the bare (unparameterized) layout type: `::Type{IJFH}` matches
# only `IJFH` itself, not a parameterized `IJFH{...}`. The `@generated`
# `_property_view` uses this form via `field_dim(AD.name.wrapper)`.
# NOTE(review): neither group has methods for `IH1JH2` / `IV1JH2` — confirm
# that is intentional.
@inline field_dim(::Type{IJKFVH}) = 4
@inline field_dim(::Type{IJFH}) = 3
@inline field_dim(::Type{IFH}) = 2
@inline field_dim(::Type{DataF}) = 1
@inline field_dim(::Type{IJF}) = 3
@inline field_dim(::Type{IF}) = 2
@inline field_dim(::Type{VF}) = 2
@inline field_dim(::Type{VIJFH}) = 4
@inline field_dim(::Type{VIFH}) = 3

"""
h_dim(::AbstractDataSingleton)
This is an internal function, please do not use outside of ClimaCore.
Expand All @@ -1267,22 +1295,19 @@ This function is helpful for writing generic
code, when reconstructing new datalayouts with new
type parameters.
"""
@inline h_dim(data::AbstractData) = h_dim(typeof(data))
@inline h_dim(::Type{<:IJKFVH}) = 5
@inline h_dim(::Type{<:IJFH}) = 4
@inline h_dim(::Type{<:IFH}) = 3
@inline h_dim(::Type{<:VIJFH}) = 5
@inline h_dim(::Type{<:VIFH}) = 4
@inline h_dim(::IJKFVHSingleton) = 5
@inline h_dim(::IJFHSingleton) = 4
@inline h_dim(::IFHSingleton) = 3
@inline h_dim(::VIJFHSingleton) = 5
@inline h_dim(::VIFHSingleton) = 4
# The value is the dimension of the parent array that `get_Nh_dynamic` passes
# to `size(parent(data), ...)`.
# NOTE(review): only the five layouts above have a method in this section;
# calling `h_dim` with any other singleton would presumably raise a
# `MethodError` — confirm no caller relies on that.

@inline to_data_specific(data::AbstractData, I::CartesianIndex) =
CartesianIndex(_to_data_specific(data, I.I))
@inline _to_data_specific(::VF, I::Tuple) = (I[4], 1)
@inline _to_data_specific(::IF, I::Tuple) = (I[1], 1)
@inline _to_data_specific(::IJF, I::Tuple) = (I[1], I[2], 1)
@inline _to_data_specific(::IJFH, I::Tuple) = (I[1], I[2], 1, I[5])
@inline _to_data_specific(::IFH, I::Tuple) = (I[1], 1, I[5])
@inline _to_data_specific(::VIJFH, I::Tuple) = (I[4], I[1], I[2], 1, I[5])
@inline _to_data_specific(::VIFH, I::Tuple) = (I[4], I[1], 1, I[5])
# Convert an index tuple `I` (entries 1, 2, 4 and 5 are read) into the
# layout-specific index tuple that `getindex` / `setindex!` wrap in a
# `CartesianIndex` and pass to `get_struct` / `set_struct!` together with
# `parent(data)`. Entries of `I` that a layout does not use are dropped, and a
# literal `1` is placed at that layout's `field_dim` position.
# NOTE(review): the slots of `I` look like (i, j, _, v, h), judging by which
# entries each layout reads — confirm against the universal-index definition.
@inline to_data_specific(::VFSingleton, I::Tuple) = (I[4], 1)
@inline to_data_specific(::IFSingleton, I::Tuple) = (I[1], 1)
@inline to_data_specific(::IJFSingleton, I::Tuple) = (I[1], I[2], 1)
@inline to_data_specific(::IJFHSingleton, I::Tuple) = (I[1], I[2], 1, I[5])
@inline to_data_specific(::IFHSingleton, I::Tuple) = (I[1], 1, I[5])
@inline to_data_specific(::VIJFHSingleton, I::Tuple) = (I[4], I[1], I[2], 1, I[5])
@inline to_data_specific(::VIFHSingleton, I::Tuple) = (I[4], I[1], 1, I[5])

"""
bounds_condition(data::AbstractData, I::Tuple)
Expand Down Expand Up @@ -1323,7 +1348,7 @@ type parameters.

"""
union_all(data::AbstractData)
union_all(::Type{<:AbstractData})
union_all(singleton(::AbstractData))
This is an internal function, please do not use outside of ClimaCore.
Expand All @@ -1334,18 +1359,17 @@ This function is helpful for writing generic
code, when reconstructing new datalayouts with new
type parameters.
"""
@inline union_all(data::AbstractData) = union_all(typeof(data))
@inline union_all(::Type{<:IJKFVH}) = IJKFVH
@inline union_all(::Type{<:IJFH}) = IJFH
@inline union_all(::Type{<:IFH}) = IFH
@inline union_all(::Type{<:DataF}) = DataF
@inline union_all(::Type{<:IJF}) = IJF
@inline union_all(::Type{<:IF}) = IF
@inline union_all(::Type{<:VF}) = VF
@inline union_all(::Type{<:VIJFH}) = VIJFH
@inline union_all(::Type{<:VIFH}) = VIFH
@inline union_all(::Type{<:IH1JH2}) = IH1JH2
@inline union_all(::Type{<:IV1JH2}) = IV1JH2
@inline union_all(::IJKFVHSingleton) = IJKFVH
@inline union_all(::IJFHSingleton) = IJFH
@inline union_all(::IFHSingleton) = IFH
@inline union_all(::DataFSingleton) = DataF
@inline union_all(::IJFSingleton) = IJF
@inline union_all(::IFSingleton) = IF
@inline union_all(::VFSingleton) = VF
@inline union_all(::VIJFHSingleton) = VIJFH
@inline union_all(::VIFHSingleton) = VIFH
@inline union_all(::IH1JH2Singleton) = IH1JH2
@inline union_all(::IV1JH2Singleton) = IV1JH2
# Each method returns the unparameterized layout type for a singleton tag, so
# callers can re-apply type parameters when rebuilding a datalayout, e.g.
# `union_all(singleton(data)){type_params(data)...}(array)`.

"""
array_size(data::AbstractData, [dim])
Expand Down Expand Up @@ -1434,11 +1458,12 @@ Base.ndims(::Type{T}) where {T <: AbstractData} =
I::CartesianIndex,
)
@boundscheck bounds_condition(data, I) || throw(BoundsError(data, I))
s = singleton(data)
@inbounds get_struct(
parent(data),
eltype(data),
Val(field_dim(data)),
to_data_specific(data, I),
Val(field_dim(s)),
CartesianIndex(to_data_specific(s, I.I)),
)
end

Expand All @@ -1448,11 +1473,12 @@ end
I::CartesianIndex,
)
@boundscheck bounds_condition(data, I) || throw(BoundsError(data, I))
s = singleton(data)
@inbounds set_struct!(
parent(data),
convert(eltype(data), val),
Val(field_dim(data)),
to_data_specific(data, I),
Val(field_dim(s)),
CartesianIndex(to_data_specific(s, I.I)),
)
end

Expand All @@ -1463,31 +1489,35 @@ if VERSION ≥ v"1.11.0-beta"
@inline Base.getindex(
data::Union{IJF, IJFH, IFH, VIJFH, VIFH, VF, IF},
I::Vararg{Int, N},
) where {N} = Base.getindex(data, to_universal_index(data, I))
) where {N} = Base.getindex(
data,
CartesianIndex(to_universal_index(singleton(data), I)),
)

@inline Base.setindex!(
data::Union{IJF, IJFH, IFH, VIJFH, VIFH, VF, IF},
val,
I::Vararg{Int, N},
) where {N} = Base.setindex!(data, val, to_universal_index(data, I))

@inline to_universal_index(data::AbstractData, I::Tuple) =
CartesianIndex(_to_universal_index(data, I))
) where {N} = Base.setindex!(
data,
val,
CartesianIndex(to_universal_index(singleton(data), I)),
)

# Certain datalayouts support special indexing.
# Like VF datalayouts with `getindex(::VF, v::Integer)`
#! format: off
@inline _to_universal_index(::VF, I::NTuple{1, T}) where {T} = (T(1), T(1), T(1), I[1], T(1))
@inline _to_universal_index(::IF, I::NTuple{1, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline _to_universal_index(::IF, I::NTuple{2, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline _to_universal_index(::IF, I::NTuple{3, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline _to_universal_index(::IF, I::NTuple{4, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline _to_universal_index(::IF, I::NTuple{5, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline _to_universal_index(::IJF, I::NTuple{2, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline _to_universal_index(::IJF, I::NTuple{3, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline _to_universal_index(::IJF, I::NTuple{4, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline _to_universal_index(::IJF, I::NTuple{5, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline _to_universal_index(::AbstractData, I::NTuple{5}) = I
# Expand a layout-specific index tuple `I` to a 5-element tuple.
# Missing slots are filled with `T(1)`, so the padding has the same element type as `I`.
#   VF : 1 index        -> placed in slot 4
#   IF : 1 to 5 indices -> only `I[1]` is kept (slot 1); the rest are replaced by `T(1)`
#   IJF: 2 to 5 indices -> only `I[1]`, `I[2]` are kept (slots 1 and 2); the rest are replaced by `T(1)`
# For every other singleton, a 5-tuple is returned unchanged by the fallback method.
@inline to_universal_index(::VFSingleton, I::NTuple{1, T}) where {T} = (T(1), T(1), T(1), I[1], T(1))
@inline to_universal_index(::IFSingleton, I::NTuple{1, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline to_universal_index(::IFSingleton, I::NTuple{2, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline to_universal_index(::IFSingleton, I::NTuple{3, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline to_universal_index(::IFSingleton, I::NTuple{4, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline to_universal_index(::IFSingleton, I::NTuple{5, T}) where {T} = (I[1], T(1), T(1), T(1), T(1))
@inline to_universal_index(::IJFSingleton, I::NTuple{2, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline to_universal_index(::IJFSingleton, I::NTuple{3, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline to_universal_index(::IJFSingleton, I::NTuple{4, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline to_universal_index(::IJFSingleton, I::NTuple{5, T}) where {T} = (I[1], I[2], T(1), T(1), T(1))
@inline to_universal_index(::AbstractDataSingleton, I::NTuple{5}) = I
#! format: on
### ---------------
end
Expand Down Expand Up @@ -1519,7 +1549,7 @@ The dimensions of `array` are assumed to be
- `([number of vertical nodes], number of horizontal nodes)`.
"""
array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
union_all(data){T, Base.tail(type_params(data))...}(
union_all(singleton(data)){T, Base.tail(type_params(data))...}(
reshape(array, array_size(data)...),
)

Expand All @@ -1529,13 +1559,27 @@ array2data(array::AbstractArray{T}, data::AbstractData) where {T} =
Returns a `ToCPU` or a `ToCUDA` for CPU-
and CUDA-backed arrays, respectively.
"""
device_dispatch(x::AbstractArray) = ToCPU()
device_dispatch(x::Array) = ToCPU()
# Wrapper types defer to the array they wrap (via `parent`).
device_dispatch(x::SubArray) = device_dispatch(parent(x))
device_dispatch(x::Base.ReshapedArray) = device_dispatch(parent(x))
device_dispatch(x::AbstractData) = device_dispatch(parent(x))
# `SArray` and `MArray` always select `ToCPU`.
device_dispatch(x::SArray) = ToCPU()
device_dispatch(x::MArray) = ToCPU()

# Map a datalayout instance to its parameter-free singleton tag.
# The argument is unnamed: it is used only to select the method by layout
# family. `@nospecialize` hints to the compiler not to specialize these
# methods on the concrete (fully parameterized) argument type.
@inline singleton(@nospecialize(::IJKFVH)) = IJKFVHSingleton()
@inline singleton(@nospecialize(::IJFH)) = IJFHSingleton()
@inline singleton(@nospecialize(::IFH)) = IFHSingleton()
@inline singleton(@nospecialize(::DataF)) = DataFSingleton()
@inline singleton(@nospecialize(::IJF)) = IJFSingleton()
@inline singleton(@nospecialize(::IF)) = IFSingleton()
@inline singleton(@nospecialize(::VF)) = VFSingleton()
@inline singleton(@nospecialize(::VIJFH)) = VIJFHSingleton()
@inline singleton(@nospecialize(::VIFH)) = VIFHSingleton()
@inline singleton(@nospecialize(::IH1JH2)) = IH1JH2Singleton()
@inline singleton(@nospecialize(::IV1JH2)) = IV1JH2Singleton()


include("copyto.jl")
include("fused_copyto.jl")
include("fill.jl")
Expand Down

0 comments on commit 90572d1

Please sign in to comment.