Skip to content

Commit

Permalink
Merge pull request #1913 from CliMA/ck/fixup_struct_tests
Browse files Browse the repository at this point in the history
Fixup get_struct unit tests
  • Loading branch information
charleskawczynski authored Aug 8, 2024
2 parents feb7247 + c0f4c9a commit bf4d141
Showing 1 changed file with 70 additions and 196 deletions.
266 changes: 70 additions & 196 deletions test/DataLayouts/unit_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Revise; include(joinpath("test", "DataLayouts", "unit_struct.jl"))
=#
using Test
using ClimaCore.DataLayouts
using ClimaCore.DataLayouts: get_struct
using StaticArrays

function one_to_n(a::Array)
Expand All @@ -14,49 +15,8 @@ function one_to_n(a::Array)
end
one_to_n(s::Tuple, ::Type{FT}) where {FT} = one_to_n(zeros(FT, s...))
ncomponents(::Type{FT}, ::Type{S}) where {FT, S} = div(sizeof(S), sizeof(FT))

function test_get_struct(::Type{FT}, ::Type{S}) where {FT, S}
s = (2,)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
for (i, ci) in enumerate(CI)
for j in 1:length(s)
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
end
end

s = (2, 3)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
for (i, ci) in enumerate(CI)
for j in 1:length(s)
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
end
end

s = (2, 3, 4)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
for (i, ci) in enumerate(CI)
for j in 1:length(s)
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
end
end

s = (2, 3, 4, 5)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
for (i, ci) in enumerate(CI)
for j in 1:length(s)
@test DataLayouts.get_struct(a, S, Val(j), ci) == FT(i)
end
end
end

@testset "get_struct - Float" begin
test_get_struct(Float64, Float64)
test_get_struct(Float32, Float32)
end
field_dim_to_one(s, dim) = Tuple(map(j-> j == dim ? 1 : s[j], 1:length(s)))
CI(s) = CartesianIndices(map-> Base.OneTo(ξ), s))

struct Foo{T}
x::T
Expand All @@ -65,175 +25,89 @@ end

Base.zero(::Type{Foo{T}}) where {T} = Foo{T}(0, 0)

@testset "get_struct - flat struct 2-fields 1-dim" begin
@testset "get_struct - IFH indexing" begin
FT = Float64
S = Foo{FT}
s = (4,)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
s_array = (3, 2, 4)
@test ncomponents(FT, S) == 2
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
@test DataLayouts.get_struct(a, S, Val(1), CI[2]) == Foo{FT}(2.0, 3.0)
@test DataLayouts.get_struct(a, S, Val(1), CI[3]) == Foo{FT}(3.0, 4.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[4])
s = field_dim_to_one(s_array, 2)
a = one_to_n(s_array, FT)
@test get_struct(a, S, Val(2), CI(s)[1]) == Foo{FT}(1.0, 4.0)
@test get_struct(a, S, Val(2), CI(s)[2]) == Foo{FT}(2.0, 5.0)
@test get_struct(a, S, Val(2), CI(s)[3]) == Foo{FT}(3.0, 6.0)
@test get_struct(a, S, Val(2), CI(s)[4]) == Foo{FT}(7.0, 10.0)
@test get_struct(a, S, Val(2), CI(s)[5]) == Foo{FT}(8.0, 11.0)
@test get_struct(a, S, Val(2), CI(s)[6]) == Foo{FT}(9.0, 12.0)
@test get_struct(a, S, Val(2), CI(s)[7]) == Foo{FT}(13.0, 16.0)
@test get_struct(a, S, Val(2), CI(s)[8]) == Foo{FT}(14.0, 17.0)
@test get_struct(a, S, Val(2), CI(s)[9]) == Foo{FT}(15.0, 18.0)
@test get_struct(a, S, Val(2), CI(s)[10]) == Foo{FT}(19.0, 22.0)
@test get_struct(a, S, Val(2), CI(s)[11]) == Foo{FT}(20.0, 23.0)
@test get_struct(a, S, Val(2), CI(s)[12]) == Foo{FT}(21.0, 24.0)
@test_throws BoundsError get_struct(a, S, Val(2), CI(s)[13])
end

@testset "get_struct - flat struct 2-fields 3-dims" begin
@testset "get_struct - IJF indexing" begin
FT = Float64
S = Foo{FT}
s = (2, 3, 4)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
s_array = (3, 4, 2)
@test ncomponents(FT, S) == 2
# Call get_struct, and span `a` (access elements to 24.0): 12 cases

# No datalayouts have field dim of 1
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[2])
@test DataLayouts.get_struct(a, S, Val(1), CI[3]) == Foo{FT}(3.0, 4.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[4])
@test DataLayouts.get_struct(a, S, Val(1), CI[5]) == Foo{FT}(5.0, 6.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[6])
@test DataLayouts.get_struct(a, S, Val(1), CI[7]) == Foo{FT}(7.0, 8.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[8])
@test DataLayouts.get_struct(a, S, Val(1), CI[9]) == Foo{FT}(9.0, 10.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[10])
@test DataLayouts.get_struct(a, S, Val(1), CI[11]) == Foo{FT}(11.0, 12.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[12])
@test DataLayouts.get_struct(a, S, Val(1), CI[13]) == Foo{FT}(13.0, 14.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[14])
@test DataLayouts.get_struct(a, S, Val(1), CI[15]) == Foo{FT}(15.0, 16.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[16])
@test DataLayouts.get_struct(a, S, Val(1), CI[17]) == Foo{FT}(17.0, 18.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[18])
@test DataLayouts.get_struct(a, S, Val(1), CI[19]) == Foo{FT}(19.0, 20.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[20])
@test DataLayouts.get_struct(a, S, Val(1), CI[21]) == Foo{FT}(21.0, 22.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[22])
@test DataLayouts.get_struct(a, S, Val(1), CI[23]) == Foo{FT}(23.0, 24.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[24])

# e.g., how IFH is indexed (field dim 2)
@test DataLayouts.get_struct(a, S, Val(2), CI[1]) == Foo{FT}(1.0, 3.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[2]) == Foo{FT}(2.0, 4.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[3]) == Foo{FT}(3.0, 5.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[4]) == Foo{FT}(4.0, 6.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[5])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[6])
@test DataLayouts.get_struct(a, S, Val(2), CI[7]) == Foo{FT}(7.0, 9.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[8]) == Foo{FT}(8.0, 10.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[9]) == Foo{FT}(9.0, 11.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[10]) == Foo{FT}(10.0, 12.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[11])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[12])
@test DataLayouts.get_struct(a, S, Val(2), CI[13]) == Foo{FT}(13.0, 15.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[14]) == Foo{FT}(14.0, 16.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[15]) == Foo{FT}(15.0, 17.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[16]) == Foo{FT}(16.0, 18.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[17])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[18])
@test DataLayouts.get_struct(a, S, Val(2), CI[19]) == Foo{FT}(19.0, 21.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[20]) == Foo{FT}(20.0, 22.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[21]) == Foo{FT}(21.0, 23.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[22]) == Foo{FT}(22.0, 24.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[23])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[24])

# e.g., how IJF is indexed (field dim 3)
@test DataLayouts.get_struct(a, S, Val(3), CI[1]) == Foo{FT}(1.0, 7.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[2]) == Foo{FT}(2.0, 8.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[3]) == Foo{FT}(3.0, 9.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[4]) == Foo{FT}(4.0, 10.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[5]) == Foo{FT}(5.0, 11.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[6]) == Foo{FT}(6.0, 12.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[7]) == Foo{FT}(7.0, 13.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[8]) == Foo{FT}(8.0, 14.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[9]) == Foo{FT}(9.0, 15.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[10]) == Foo{FT}(10.0, 16.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[11]) == Foo{FT}(11.0, 17.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[12]) == Foo{FT}(12.0, 18.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[13]) == Foo{FT}(13.0, 19.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[14]) == Foo{FT}(14.0, 20.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[15]) == Foo{FT}(15.0, 21.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[16]) == Foo{FT}(16.0, 22.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[17]) == Foo{FT}(17.0, 23.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[18]) == Foo{FT}(18.0, 24.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(3), CI[19])
s = field_dim_to_one(s_array, 3)
a = one_to_n(s_array, FT)
@test get_struct(a, S, Val(3), CI(s)[1]) == Foo{FT}(1.0, 13.0)
@test get_struct(a, S, Val(3), CI(s)[2]) == Foo{FT}(2.0, 14.0)
@test get_struct(a, S, Val(3), CI(s)[3]) == Foo{FT}(3.0, 15.0)
@test get_struct(a, S, Val(3), CI(s)[4]) == Foo{FT}(4.0, 16.0)
@test get_struct(a, S, Val(3), CI(s)[5]) == Foo{FT}(5.0, 17.0)
@test get_struct(a, S, Val(3), CI(s)[6]) == Foo{FT}(6.0, 18.0)
@test get_struct(a, S, Val(3), CI(s)[7]) == Foo{FT}(7.0, 19.0)
@test get_struct(a, S, Val(3), CI(s)[8]) == Foo{FT}(8.0, 20.0)
@test get_struct(a, S, Val(3), CI(s)[9]) == Foo{FT}(9.0, 21.0)
@test get_struct(a, S, Val(3), CI(s)[10]) == Foo{FT}(10.0, 22.0)
@test get_struct(a, S, Val(3), CI(s)[11]) == Foo{FT}(11.0, 23.0)
@test get_struct(a, S, Val(3), CI(s)[12]) == Foo{FT}(12.0, 24.0)
@test_throws BoundsError get_struct(a, S, Val(3), CI(s)[13])
end

@testset "get_struct - flat struct 2-fields 5-dims" begin
@testset "get_struct - VIJFH indexing" begin
FT = Float64
S = Foo{FT}
s = (2, 2, 2, 2, 2)
a = one_to_n(s, FT)
CI = CartesianIndices(map-> Base.OneTo(ξ), s))
@test ncomponents(FT, S) == 2

# Call get_struct, and span `a` (access elements to 2^5 = 32.0):
@test DataLayouts.get_struct(a, S, Val(1), CI[1]) == Foo{FT}(1.0, 2.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(1), CI[2])

@test DataLayouts.get_struct(a, S, Val(2), CI[1]) == Foo{FT}(1.0, 3.0)
@test DataLayouts.get_struct(a, S, Val(2), CI[2]) == Foo{FT}(2.0, 4.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(2), CI[3])

@test DataLayouts.get_struct(a, S, Val(3), CI[1]) == Foo{FT}(1.0, 5.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[2]) == Foo{FT}(2.0, 6.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[3]) == Foo{FT}(3.0, 7.0)
@test DataLayouts.get_struct(a, S, Val(3), CI[4]) == Foo{FT}(4.0, 8.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(3), CI[5])

# e.g., how VIJFH is indexed (field dim 4)
@test DataLayouts.get_struct(a, S, Val(4), CI[1]) == Foo{FT}(1.0, 9.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[2]) == Foo{FT}(2.0, 10.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[3]) == Foo{FT}(3.0, 11.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[4]) == Foo{FT}(4.0, 12.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[5]) == Foo{FT}(5.0, 13.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[6]) == Foo{FT}(6.0, 14.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[7]) == Foo{FT}(7.0, 15.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[8]) == Foo{FT}(8.0, 16.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[9])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[10])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[11])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[12])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[13])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[14])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[15])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[16])
@test DataLayouts.get_struct(a, S, Val(4), CI[17]) == Foo{FT}(17.0, 25.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[18]) == Foo{FT}(18.0, 26.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[19]) == Foo{FT}(19.0, 27.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[20]) == Foo{FT}(20.0, 28.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[21]) == Foo{FT}(21.0, 29.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[22]) == Foo{FT}(22.0, 30.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[23]) == Foo{FT}(23.0, 31.0)
@test DataLayouts.get_struct(a, S, Val(4), CI[24]) == Foo{FT}(24.0, 32.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[25])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[26])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[27])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[28])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[29])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[30])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[31])
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(4), CI[32])

@test DataLayouts.get_struct(a, S, Val(5), CI[1]) == Foo{FT}(1.0, 17.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[2]) == Foo{FT}(2.0, 18.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[3]) == Foo{FT}(3.0, 19.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[4]) == Foo{FT}(4.0, 20.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[5]) == Foo{FT}(5.0, 21.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[6]) == Foo{FT}(6.0, 22.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[7]) == Foo{FT}(7.0, 23.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[8]) == Foo{FT}(8.0, 24.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[9]) == Foo{FT}(9.0, 25.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[10]) == Foo{FT}(10.0, 26.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[11]) == Foo{FT}(11.0, 27.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[12]) == Foo{FT}(12.0, 28.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[13]) == Foo{FT}(13.0, 29.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[14]) == Foo{FT}(14.0, 30.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[15]) == Foo{FT}(15.0, 31.0)
@test DataLayouts.get_struct(a, S, Val(5), CI[16]) == Foo{FT}(16.0, 32.0)
@test_throws BoundsError DataLayouts.get_struct(a, S, Val(5), CI[17])
@test get_struct(a, S, Val(4), CI(s)[1]) == Foo{FT}(1.0, 9.0)
@test get_struct(a, S, Val(4), CI(s)[2]) == Foo{FT}(2.0, 10.0)
@test get_struct(a, S, Val(4), CI(s)[3]) == Foo{FT}(3.0, 11.0)
@test get_struct(a, S, Val(4), CI(s)[4]) == Foo{FT}(4.0, 12.0)
@test get_struct(a, S, Val(4), CI(s)[5]) == Foo{FT}(5.0, 13.0)
@test get_struct(a, S, Val(4), CI(s)[6]) == Foo{FT}(6.0, 14.0)
@test get_struct(a, S, Val(4), CI(s)[7]) == Foo{FT}(7.0, 15.0)
@test get_struct(a, S, Val(4), CI(s)[8]) == Foo{FT}(8.0, 16.0)
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[9])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[10])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[11])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[12])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[13])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[14])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[15])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[16])
@test get_struct(a, S, Val(4), CI(s)[17]) == Foo{FT}(17.0, 25.0)
@test get_struct(a, S, Val(4), CI(s)[18]) == Foo{FT}(18.0, 26.0)
@test get_struct(a, S, Val(4), CI(s)[19]) == Foo{FT}(19.0, 27.0)
@test get_struct(a, S, Val(4), CI(s)[20]) == Foo{FT}(20.0, 28.0)
@test get_struct(a, S, Val(4), CI(s)[21]) == Foo{FT}(21.0, 29.0)
@test get_struct(a, S, Val(4), CI(s)[22]) == Foo{FT}(22.0, 30.0)
@test get_struct(a, S, Val(4), CI(s)[23]) == Foo{FT}(23.0, 31.0)
@test get_struct(a, S, Val(4), CI(s)[24]) == Foo{FT}(24.0, 32.0)
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[25])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[26])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[27])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[28])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[29])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[30])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[31])
@test_throws BoundsError get_struct(a, S, Val(4), CI(s)[32])
end

# TODO: add set_struct!

0 comments on commit bf4d141

Please sign in to comment.