Skip to content

Commit

Permalink
Reduce more code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 14, 2024
1 parent dc72795 commit 9412925
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 89 deletions.
100 changes: 14 additions & 86 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,6 @@ function IJFH{S, Nij, Nh}(array::AbstractArray{T, 4}) where {S, Nij, Nh, T}
IJFH{S, Nij, Nh, typeof(array)}(array)
end

rebuild(
data::IJFH{S, Nij, Nh},
array::A,
) where {S, Nij, Nh, A <: AbstractArray} = IJFH{S, Nij, Nh}(array)

Base.copy(data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
IJFH{S, Nij, Nh}(copy(parent(data)))

@inline universal_size(::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
(Nij, Nij, 1, 1, Nh)

Expand Down Expand Up @@ -477,12 +469,6 @@ function IFH{S, Ni, Nh}(::Type{ArrayType}) where {S, Ni, Nh, ArrayType}
IFH{S, Ni, Nh}(ArrayType(undef, Ni, typesize(T, S), Nh))
end

rebuild(data::IFH{S, Ni, Nh}, array::AbstractArray{T, 3}) where {S, Ni, Nh, T} =
IFH{S, Ni, Nh}(array)

Base.copy(data::IFH{S, Ni, Nh}) where {S, Ni, Nh} =
IFH{S, Ni, Nh}(copy(parent(data)))

@inline universal_size(::IFH{S, Ni, Nh}) where {S, Ni, Nh} = (Ni, 1, 1, 1, Nh)

@inline function slab(data::IFH{S, Ni}, h::Integer) where {S, Ni}
Expand Down Expand Up @@ -538,8 +524,6 @@ struct DataF{S, A} <: Data0D{S}
array::A
end

rebuild(data::DataF{S}, array::AbstractArray) where {S} = DataF{S}(array)

parent_array_type(::Type{DataF{S, A}}) where {S, A} = A

function DataF{S}(array::AbstractVector{T}) where {S, T}
Expand Down Expand Up @@ -597,8 +581,6 @@ end
@inbounds col[] = val
end

Base.copy(data::DataF{S}) where {S} = DataF{S}(copy(parent(data)))

# ======================
# DataSlab2D DataLayout
# ======================
Expand Down Expand Up @@ -640,8 +622,6 @@ function IJF{S, Nij}(array::AbstractArray{T, 3}) where {S, Nij, T}
IJF{S, Nij, typeof(array)}(array)
end

rebuild(data::IJF{S, Nij}, array::A) where {S, Nij, A <: AbstractArray} =
IJF{S, Nij}(array)
function IJF{S, Nij}(::Type{MArray}, ::Type{T}) where {S, Nij, T}
Nf = typesize(T, S)
array = MArray{Tuple{Nij, Nij, Nf}, T, 3, Nij * Nij * Nf}(undef)
Expand Down Expand Up @@ -724,9 +704,6 @@ struct IF{S, Ni, A} <: DataSlab1D{S, Ni}
array::A
end

rebuild(data::IF{S, Nij}, array::A) where {S, Nij, A <: AbstractArray} =
IF{S, Nij, A}(array)

parent_array_type(::Type{IF{S, Ni, A}}) where {S, Ni, A} = A

function IF{S, Ni}(array::AbstractArray{T, 2}) where {S, Ni, T}
Expand Down Expand Up @@ -817,10 +794,6 @@ function VF{S, Nv}(::Type{ArrayType}, nelements) where {S, Nv, ArrayType}
VF{S, Nv}(ArrayType(undef, nelements, typesize(T, S)))
end

rebuild(data::VF{S, Nv}, array::AbstractArray{T, 2}) where {S, Nv, T} =
VF{S, Nv}(array)

Base.copy(data::VF{S, Nv}) where {S, Nv} = VF{S, Nv}(copy(parent(data)))
Base.lastindex(data::VF) = length(data)

nlevels(::VF{S, Nv}) where {S, Nv} = Nv
Expand Down Expand Up @@ -891,23 +864,15 @@ function VIJFH{S, Nv, Nij, Nh}(
array::AbstractArray{T, 5},
) where {S, Nv, Nij, Nh, T}
check_basetype(T, S)
@assert size(array, 1) == Nv
@assert size(array, 2) == size(array, 3) == Nij
@assert size(array, 4) == typesize(T, S)
@assert size(array, 5) == Nh
VIJFH{S, Nv, Nij, Nh, typeof(array)}(array)
end

rebuild(
data::VIJFH{S, Nv, Nij, Nh},
array::AbstractArray{T, 5},
) where {S, Nv, Nij, Nh, T} = VIJFH{S, Nv, Nij, Nh}(array)

nlevels(::VIJFH{S, Nv}) where {S, Nv} = Nv

function Base.copy(data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh}
VIJFH{S, Nv, Nij, Nh}(copy(parent(data)))
end

@inline universal_size(::VIJFH{<:Any, Nv, Nij, Nh}) where {Nv, Nij, Nh} =
(Nij, Nij, 1, Nv, Nh)

Expand Down Expand Up @@ -1015,22 +980,15 @@ function VIFH{S, Nv, Ni, Nh}(
array::AbstractArray{T, 4},
) where {S, Nv, Ni, Nh, T}
check_basetype(T, S)
@assert size(array, 1) == Nv
@assert size(array, 2) == Ni
@assert size(array, 3) == typesize(T, S)
@assert size(array, 4) == Nh
VIFH{S, Nv, Ni, Nh, typeof(array)}(array)
end

rebuild(
data::VIFH{S, Nv, Ni, Nh},
array::A,
) where {S, Nv, Ni, Nh, A <: AbstractArray} = VIFH{S, Nv, Ni, Nh}(array)

nlevels(::VIFH{S, Nv}) where {S, Nv} = Nv

Base.copy(data::VIFH{S, Nv, Ni, Nh}) where {S, Nv, Ni, Nh} =
VIFH{S, Nv, Ni, Nh}(copy(parent(data)))

@inline universal_size(::VIFH{<:Any, Nv, Ni, Nh}) where {Nv, Ni, Nh} =
(Ni, 1, 1, Nv, Nh)

Expand All @@ -1041,7 +999,7 @@ Base.length(data::VIFH) = nlevels(data) * get_Nh(data)
array = parent(data)
@boundscheck (1 <= v <= Nv && 1 <= h <= Nh) ||
throw(BoundsError(data, (v, h)))
Nf = size(array, 3)
Nf = ncomponents(data)
dataview = @inbounds SubArray(
array,
(v, Base.Slice(Base.OneTo(Ni)), Base.Slice(Base.OneTo(Nf)), h),
Expand All @@ -1054,7 +1012,7 @@ end
array = parent(data)
@boundscheck (1 <= i <= Ni && 1 <= h <= Nh) ||
throw(BoundsError(data, (i, h)))
Nf = size(array, 3)
Nf = ncomponents(data)
dataview = @inbounds SubArray(
array,
(Base.Slice(Base.OneTo(Nv)), i, Base.Slice(Base.OneTo(Nf)), h),
Expand All @@ -1071,7 +1029,7 @@ end
array = parent(data)
@boundscheck (1 <= i <= Ni && j == 1 && 1 <= h <= Nh) ||
throw(BoundsError(data, (i, j, h)))
Nf = size(array, 3)
Nf = ncomponents(data)
dataview = @inbounds SubArray(
array,
(Base.Slice(Base.OneTo(Nv)), i, Base.Slice(Base.OneTo(Nf)), h),
Expand Down Expand Up @@ -1135,10 +1093,7 @@ function IH1JH2{S, Nij}(array::AbstractMatrix{S}) where {S, Nij}
IH1JH2{S, Nij, typeof(array)}(array)
end

Base.copy(data::IH1JH2{S, Nij}) where {S, Nij} =
IH1JH2{S, Nij}(copy(parent(data)))

@inline universal_size(::IH1JH2{S, Nij}) where {S, Nij} =
@inline universal_size(data::IH1JH2{S, Nij}) where {S, Nij} =
(Nij, Nij, 1, 1, div(array_length(data), Nij * Nij))

Base.length(data::IH1JH2{S, Nij}) where {S, Nij} =
Expand Down Expand Up @@ -1180,10 +1135,7 @@ function IV1JH2{S, n1, Ni}(array::AbstractMatrix{S}) where {S, n1, Ni}
IV1JH2{S, n1, Ni, typeof(array)}(array)
end

Base.copy(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =
IV1JH2{S, n1, Ni}(copy(parent(data)))

@inline universal_size(::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =
@inline universal_size(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =
(Ni, 1, 1, n1, div(size(parent(data), 2), Ni))

Base.length(data::IV1JH2{S, n1, Ni}) where {S, n1, Ni} =
Expand Down Expand Up @@ -1214,41 +1166,17 @@ end
rebuild(data::AbstractData, ::Type{DA}) where {DA} =
rebuild(data, DA(getfield(data, :array)))

Base.copy(data::AbstractData) =
union_all(data){type_params(data)...}(copy(parent(data)))

# broadcast machinery
include("broadcast.jl")

Adapt.adapt_structure(to, data::AbstractData{S}) where {S} =
union_all(data){type_params(data)...}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(
to,
data::IJKFVH{S, Nij, Nk, Nv, Nh},
) where {S, Nij, Nk, Nv, Nh} =
IJKFVH{S, Nij, Nk, Nv, Nh}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::IJFH{S, Nij, Nh}) where {S, Nij, Nh} =
IJFH{S, Nij, Nh}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::VIJFH{S, Nv, Nij, Nh}) where {S, Nv, Nij, Nh} =
VIJFH{S, Nv, Nij, Nh}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(
to,
data::VIFH{S, Nv, Ni, Nh, A},
) where {S, Nv, Ni, Nh, A} = VIFH{S, Nv, Ni, Nh}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::IFH{S, Ni, Nh}) where {S, Ni, Nh} =
IFH{S, Ni, Nh}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::IJF{S, Nij}) where {S, Nij} =
IJF{S, Nij}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::IF{S, Ni}) where {S, Ni} =
IF{S, Ni}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::VF{S, Nv}) where {S, Nv} =
VF{S, Nv}(Adapt.adapt(to, parent(data)))

Adapt.adapt_structure(to, data::DataF{S}) where {S} =
DataF{S}(Adapt.adapt(to, parent(data)))
rebuild(data::AbstractData, array::AbstractArray) =
union_all(data){type_params(data)...}(array)

empty_kernel_stats(::ClimaComms.AbstractDevice) = nothing
empty_kernel_stats() = empty_kernel_stats(ClimaComms.device())
Expand Down
6 changes: 3 additions & 3 deletions test/Spaces/opt_spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ end
else
test_n_failures(0, TU.PointSpace, context)
test_n_failures(137, TU.SpectralElementSpace1D, context)
test_n_failures(298, TU.SpectralElementSpace2D, context)
test_n_failures(308, TU.SpectralElementSpace2D, context)
test_n_failures(118, TU.ColumnCenterFiniteDifferenceSpace, context)
test_n_failures(118, TU.ColumnFaceFiniteDifferenceSpace, context)
test_n_failures(304, TU.SphereSpectralElementSpace, context)
test_n_failures(314, TU.SphereSpectralElementSpace, context)
test_n_failures(321, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(321, TU.FaceExtrudedFiniteDifferenceSpace, context)

Expand All @@ -60,7 +60,7 @@ end

result = JET.@report_opt Grids._SpectralElementGrid2D(Spaces.topology(space), Spaces.quadrature_style(space); enable_bubble=false)
n_found = length(JET.get_reports(result.analyzer, result.result))
n_allowed = 177
n_allowed = 187
@test n_found n_allowed
if n_found < n_allowed
@info "Inference may have improved for _SpectralElementGrid2D: (n_found, n_allowed) = ($n_found, $n_allowed)"
Expand Down

0 comments on commit 9412925

Please sign in to comment.