Skip to content

Commit

Permalink
Merge pull request #2033 from CliMA/ck/datalayout_kw_constructors
Browse files Browse the repository at this point in the history
Add DataLayouts convenience constructors
  • Loading branch information
charleskawczynski authored Oct 11, 2024
2 parents 0b6aa34 + 4474f47 commit 49d7e80
Show file tree
Hide file tree
Showing 16 changed files with 569 additions and 340 deletions.
272 changes: 269 additions & 3 deletions src/DataLayouts/DataLayouts.jl

Large diffs are not rendered by default.

25 changes: 12 additions & 13 deletions test/DataLayouts/benchmark_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,38 +40,37 @@ end

@testset "copyto! with Nf = 1" begin
device = ClimaComms.device()
device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
ArrayType = ClimaComms.array_type(device)
FT = Float64
S = FT
Nf = 1
Nv = 63
Nij = 4
Ni = Nij = 4
Nh = 30 * 30 * 6
Nk = 6
bm = Benchmark(; float_type = FT, device_name)
data = DataF{S}(device_zeros(FT, Nf))
data = DataF{S}(ArrayType{FT}, zeros)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJFH{S, Nij}(device_zeros(FT, Nij, Nij, Nf, Nh))
data = IJFH{S}(ArrayType{FT}, zeros; Nij, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IFH{S, Nij}(device_zeros(FT, Nij, Nf, Nh))
data = IFH{S}(ArrayType{FT}, zeros; Ni, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
# The parent array of IJF and IF datalayouts are MArrays, and can therefore not bm, be passed into CUDA kernels on the RHS.
# data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
# data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT, Nv, Nf))
# data = IJF{S}(ArrayType{FT}, zeros; Nij); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
# data = IF{S}(ArrayType{FT}, zeros; Ni); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
data = VF{S}(ArrayType{FT}, zeros; Nv)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIJFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh))
data = VIJFH{S}(ArrayType{FT}, zeros; Nv, Nij, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nf, Nh))
data = VIFH{S}(ArrayType{FT}, zeros; Nv, Ni, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)

# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nh); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(ArrayType{FT}, zeros; Nij,Nk,Nh); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
tabulate_benchmark(bm)
end
43 changes: 21 additions & 22 deletions test/DataLayouts/benchmark_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end

include(joinpath(pkgdir(ClimaCore), "benchmarks/scripts/benchmark_utils.jl"))

function benchmarkfill!(bm, device, data, val, name)
function benchmarkfill!(bm, device, data, val)
caller = string(nameof(typeof(data)))
@info "Benchmarking $caller..."
trial = @benchmark ClimaComms.@cuda_sync $device fill!($data, $val)
Expand All @@ -37,41 +37,40 @@ end

@testset "fill! with Nf = 1" begin
device = ClimaComms.device()
device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
ArrayType = ClimaComms.array_type(device)
FT = Float64
S = FT
Nf = 1
Nv = 63
Nij = 4
Ni = Nij = 4
Nh = 30 * 30 * 6
Nk = 6
bm = Benchmark(; float_type = FT, device_name)
data = DataF{S}(device_zeros(FT, Nf))
benchmarkfill!(bm, device, data, 3, "DataF")
data = DataF{S}(ArrayType{FT}, zeros)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJFH{S, Nij}(device_zeros(FT, Nij, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "IJFH")
data = IJFH{S}(ArrayType{FT}, zeros; Nij, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IFH{S, Nij}(device_zeros(FT, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "IFH")
data = IFH{S}(ArrayType{FT}, zeros; Ni, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJF{S, Nij}(device_zeros(FT, Nij, Nij, Nf))
benchmarkfill!(bm, device, data, 3, "IJF")
data = IJF{S}(ArrayType{FT}, zeros; Nij)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IF{S, Nij}(device_zeros(FT, Nij, Nf))
benchmarkfill!(bm, device, data, 3, "IF")
data = IF{S}(ArrayType{FT}, zeros; Ni)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT, Nv, Nf))
benchmarkfill!(bm, device, data, 3, "VF")
data = VF{S}(ArrayType{FT}, zeros; Nv)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIJFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "VIJFH")
data = VIJFH{S}(ArrayType{FT}, zeros; Nv, Nij, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "VIFH")
data = VIFH{S}(ArrayType{FT}, zeros; Nv, Ni, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)

# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
tabulate_benchmark(bm)
end
21 changes: 8 additions & 13 deletions test/DataLayouts/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ end
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
Nh = 10
src = IJFH{S, 4}(ArrayType(rand(4, 4, 3, Nh)))
dst = IJFH{S, 4}(ArrayType(zeros(4, 4, 3, Nh)))
src = IJFH{S}(ArrayType{Float64}, rand; Nij = 4, Nh)
dst = IJFH{S}(ArrayType{Float64}, zeros; Nij = 4, Nh)

test_copy!(dst, src)

Expand All @@ -47,21 +47,17 @@ end
Nh = 2
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
data_arr1 = ArrayType(ones(FT, 2, 2, 3, Nh))
data_arr2 = ArrayType(ones(FT, 2, 2, 1, Nh))
data1 = IJFH{S1, 2}(data_arr1)
data2 = IJFH{S2, 2}(data_arr2)
data1 = IJFH{S1}(ArrayType{FT}, ones; Nij = 2, Nh)
data2 = IJFH{S2}(ArrayType{FT}, ones; Nij = 2, Nh)

f1(a1, a2) = a1.a.re * a2 + a1.b
res = f1.(data1, data2)
@test res isa IJFH{Float64}
@test Array(parent(res)) == FT[2 for i in 1:2, j in 1:2, f in 1:1, h in 1:2]

Nv = 33
data_arr1 = ArrayType(ones(FT, Nv, 4, 4, 3, 2))
data_arr2 = ArrayType(ones(FT, Nv, 4, 4, 1, 2))
data1 = VIJFH{S1, Nv, 4}(data_arr1)
data2 = VIJFH{S2, Nv, 4}(data_arr2)
data1 = VIJFH{S1}(ArrayType{FT}, ones; Nv, Nij = 4, Nh = 2)
data2 = VIJFH{S2}(ArrayType{FT}, ones; Nv, Nij = 4, Nh = 2)

f2(a1, a2) = a1.a.re * a2 + a1.b
res = f2.(data1, data2)
Expand All @@ -77,14 +73,13 @@ end
Nh = 3
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
array = similar(ArrayType{FT}, 2, 2, 2, Nh)
data = IJFH{S, 2}(array)
data = IJFH{S}(ArrayType{FT}; Nij = 2, Nh)
data .= Complex(1.0, 2.0)
@test Array(parent(data)) ==
FT[f == 1 ? 1 : 2 for i in 1:2, j in 1:2, f in 1:2, h in 1:3]

Nv = 33
data = VIJFH{S, Nv, 4}(ArrayType{FT}(undef, Nv, 4, 4, 2, Nh))
data = VIJFH{S}(ArrayType{FT}; Nv, Nij = 4, Nh)
data .= Complex(1.0, 2.0)
@test Array(parent(data)) == FT[
f == 1 ? 1 : 2 for v in 1:Nv, i in 1:4, j in 1:4, f in 1:2, h in 1:3
Expand Down
74 changes: 37 additions & 37 deletions test/DataLayouts/data0d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ using Revise; include(joinpath("test", "DataLayouts", "data0d.jl"))
using Test
using JET

using ClimaComms
using ClimaCore.DataLayouts
using StaticArrays
using ClimaCore.DataLayouts: get_struct, set_struct!

TestFloatTypes = (Float32, Float64)
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)

@testset "DataF" begin
for FT in TestFloatTypes
S = Tuple{Complex{FT}, FT}
array = rand(FT, 3)

data = DataF{S}(array)
data = DataF{S}(ArrayType{FT}, rand)
array = parent(data)
@test getfield(data, :array) == array

# test tuple assignment
Expand All @@ -42,8 +45,7 @@ end

@testset "DataF boundscheck" begin
S = Tuple{Complex{Float64}, Float64}
array = zeros(Float64, 3)
data = DataF{S}(array)
data = DataF{S}(ArrayType{Float64}, zeros)
@test data[][2] == zero(Float64)
@test_throws MethodError data[1]
end
Expand All @@ -53,8 +55,7 @@ end
SA = (a = 1.0, b = 2.0)
SB = (c = 1.0, d = 2.0)

array = zeros(Float64, 2)
data = DataF{typeof(SA)}(array)
data = DataF{typeof(SA)}(ArrayType{Float64}, zeros)

ret = begin
data[] = SA
Expand All @@ -66,9 +67,8 @@ end

@testset "DataF broadcasting between 0D data objects and scalars" begin
for FT in TestFloatTypes
data1 = ones(FT, 2)
S = Complex{FT}
data1 = DataF{S}(data1)
data1 = DataF{S}(ArrayType{FT}, ones)
res = data1 .+ 1
@test res isa DataF
@test parent(res) == FT[2.0, 1.0]
Expand All @@ -91,12 +91,10 @@ end

@testset "DataF broadcasting between 0D data objects" begin
for FT in TestFloatTypes
data1 = ones(FT, 2)
data2 = ones(FT, 1)
S1 = Complex{FT}
S2 = FT
data1 = DataF{S1}(data1)
data2 = DataF{S2}(data2)
data1 = DataF{S1}(ArrayType{FT}, ones)
data2 = DataF{S2}(ArrayType{FT}, ones)
res = data1 .+ data2
@test res isa DataF{S1}
@test parent(res) == FT[2.0, 1.0]
Expand All @@ -108,8 +106,8 @@ end
FT = Float64
S = Complex{FT}
Nv = 3
data_f = DataF{S}(ones(FT, 2))
data_vf = VF{S, Nv}(ones(FT, Nv, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_vf = VF{S}(ArrayType{FT}, ones; Nv)
data_vf2 = data_f .+ data_vf
@test data_vf2 isa VF{S, Nv}
@test size(data_vf2) == (1, 1, 1, 3, 1)
Expand All @@ -118,8 +116,8 @@ end
@testset "broadcasting DataF + IF data object => IF" begin
FT = Float64
S = Complex{FT}
data_f = DataF{S}(ones(FT, 2))
data_if = IF{S, 2}(ones(FT, 2, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_if = IF{S}(ArrayType{FT}, ones; Ni = 2)
data_if2 = data_f .+ data_if
@test data_if2 isa IF{S}
@test size(data_if2) == (2, 1, 1, 1, 1)
Expand All @@ -129,8 +127,8 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
data_f = DataF{S}(ones(FT, 2))
data_ifh = IFH{S, 2}(ones(FT, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ifh = IFH{S}(ArrayType{FT}, ones; Ni = 2, Nh)
data_ifh2 = data_f .+ data_ifh
@test data_ifh2 isa IFH{S}
@test size(data_ifh2) == (2, 1, 1, 1, 3)
Expand All @@ -139,8 +137,8 @@ end
@testset "broadcasting DataF + IJF data object => IJF" begin
FT = Float64
S = Complex{FT}
data_f = DataF{S}(ones(FT, 2))
data_ijf = IJF{S, 2}(ones(FT, 2, 2, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ijf = IJF{S}(ArrayType{FT}, ones; Nij = 2)
data_ijf2 = data_f .+ data_ijf
@test data_ijf2 isa IJF{S}
@test size(data_ijf2) == (2, 2, 1, 1, 1)
Expand All @@ -150,8 +148,8 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
data_f = DataF{S}(ones(FT, 2))
data_ijfh = IJFH{S, 2}(ones(2, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ijfh = IJFH{S}(ArrayType{FT}, ones; Nij = 2, Nh)
data_ijfh2 = data_f .+ data_ijfh
@test data_ijfh2 isa IJFH{S}
@test size(data_ijfh2) == (2, 2, 1, 1, Nh)
Expand All @@ -161,9 +159,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 10
data_f = DataF{S}(ones(FT, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
Nv = 10
data_vifh = VIFH{S, Nv, 4}(ones(FT, Nv, 4, 2, Nh))
data_vifh = VIFH{S}(ArrayType{FT}, ones; Nv, Ni = 4, Nh)
data_vifh2 = data_f .+ data_vifh
@test data_vifh2 isa VIFH{S, Nv}
@test size(data_vifh2) == (4, 1, 1, Nv, Nh)
Expand All @@ -174,8 +172,8 @@ end
S = Complex{FT}
Nv = 2
Nh = 2
data_f = DataF{S}(ones(FT, 2))
data_vijfh = VIJFH{S, Nv, 2}(ones(FT, Nv, 2, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_vijfh = VIJFH{S}(ArrayType{FT}, ones; Nv, Nij = 2, Nh)
data_vijfh2 = data_f .+ data_vijfh
@test data_vijfh2 isa VIJFH{S, Nv}
@test size(data_vijfh2) == (2, 2, 1, Nv, Nh)
Expand All @@ -184,8 +182,9 @@ end
@testset "column IF => DataF" begin
FT = Float64
S = Complex{FT}
array = FT[1 2; 3 4]
data_if = IF{S, 2}(array)
data_if = IF{S}(ArrayType{FT}; Ni = 2)
array = parent(data_if)
array .= FT[1 2; 3 4]
if_column = column(data_if, 2)
@test if_column isa DataF
@test if_column[] == 3.0 + 4.0im
Expand All @@ -196,9 +195,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
array = ones(FT, 2, 2, Nh)
data_ifh = IFH{S}(ArrayType{FT}; Ni = 2, Nh)
array = parent(data_ifh)
array[1, :, 1] .= FT[3, 4]
data_ifh = IFH{S, 2}(array)
ifh_column = column(data_ifh, 1, 1)
@test ifh_column isa DataF
@test ifh_column[] == 3.0 + 4.0im
Expand All @@ -209,9 +208,9 @@ end
@testset "column IJF => DataF" begin
FT = Float64
S = Complex{FT}
array = ones(FT, 2, 2, 2)
data_ijf = IJF{S}(ArrayType{FT}; Nij = 2)
array = parent(data_ijf)
array[1, 1, :] .= FT[3, 4]
data_ijf = IJF{S, 2}(array)
ijf_column = column(data_ijf, 1, 1)
@test ijf_column isa DataF
@test ijf_column[] == 3.0 + 4.0im
Expand All @@ -223,9 +222,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
array = ones(2, 2, 2, 3)
data_ijfh = IJFH{S}(ArrayType{FT}; Nij = 2, Nh)
array = parent(data_ijfh)
array[1, 1, :, 2] .= FT[3, 4]
data_ijfh = IJFH{S, 2}(array)
ijfh_column = column(data_ijfh, 1, 1, 2)
@test ijfh_column isa DataF
@test ijfh_column[] == 3.0 + 4.0im
Expand All @@ -237,9 +236,10 @@ end
@testset "level VF => DataF" begin
FT = Float64
S = Complex{FT}
array = FT[1 2; 3 4; 5 6]
Nv = size(array, 1)
data_vf = VF{S, Nv}(array)
Nv = 3
data_vf = VF{S}(ArrayType{FT}; Nv)
array = parent(data_vf)
array .= FT[1 2; 3 4; 5 6]
vf_level = level(data_vf, 2)
@test vf_level isa DataF
@test vf_level[] == 3.0 + 4.0im
Expand Down
Loading

0 comments on commit 49d7e80

Please sign in to comment.