Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DataLayouts convenience constructors #2033

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 269 additions & 3 deletions src/DataLayouts/DataLayouts.jl

Large diffs are not rendered by default.

25 changes: 12 additions & 13 deletions test/DataLayouts/benchmark_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,38 +40,37 @@ end

@testset "copyto! with Nf = 1" begin
device = ClimaComms.device()
device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
ArrayType = ClimaComms.array_type(device)
FT = Float64
S = FT
Nf = 1
Nv = 63
Nij = 4
Ni = Nij = 4
Nh = 30 * 30 * 6
Nk = 6
bm = Benchmark(; float_type = FT, device_name)
data = DataF{S}(device_zeros(FT, Nf))
data = DataF{S}(ArrayType{FT}, zeros)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJFH{S, Nij}(device_zeros(FT, Nij, Nij, Nf, Nh))
data = IJFH{S}(ArrayType{FT}, zeros; Nij, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IFH{S, Nij}(device_zeros(FT, Nij, Nf, Nh))
data = IFH{S}(ArrayType{FT}, zeros; Ni, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
# The parent array of IJF and IF datalayouts are MArrays, and can therefore not bm, be passed into CUDA kernels on the RHS.
# data = IJF{S, Nij}(device_zeros(FT,Nij,Nij,Nf)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
# data = IF{S, Nij}(device_zeros(FT,Nij,Nf)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT, Nv, Nf))
# data = IJF{S}(ArrayType{FT}, zeros; Nij); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
# data = IF{S}(ArrayType{FT}, zeros; Ni); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3)
data = VF{S}(ArrayType{FT}, zeros; Nv)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIJFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh))
data = VIJFH{S}(ArrayType{FT}, zeros; Nv, Nij, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nf, Nh))
data = VIFH{S}(ArrayType{FT}, zeros; Nv, Ni, Nh)
benchmarkcopyto!(bm, device, data, 3)
@test all(parent(data) .== 3)

# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nh); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(ArrayType{FT}, zeros; Nij,Nk,Nh); benchmarkcopyto!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
tabulate_benchmark(bm)
end
43 changes: 21 additions & 22 deletions test/DataLayouts/benchmark_fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ end

include(joinpath(pkgdir(ClimaCore), "benchmarks/scripts/benchmark_utils.jl"))

function benchmarkfill!(bm, device, data, val, name)
function benchmarkfill!(bm, device, data, val)
caller = string(nameof(typeof(data)))
@info "Benchmarking $caller..."
trial = @benchmark ClimaComms.@cuda_sync $device fill!($data, $val)
Expand All @@ -37,41 +37,40 @@ end

@testset "fill! with Nf = 1" begin
device = ClimaComms.device()
device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
ArrayType = ClimaComms.array_type(device)
FT = Float64
S = FT
Nf = 1
Nv = 63
Nij = 4
Ni = Nij = 4
Nh = 30 * 30 * 6
Nk = 6
bm = Benchmark(; float_type = FT, device_name)
data = DataF{S}(device_zeros(FT, Nf))
benchmarkfill!(bm, device, data, 3, "DataF")
data = DataF{S}(ArrayType{FT}, zeros)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJFH{S, Nij}(device_zeros(FT, Nij, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "IJFH")
data = IJFH{S}(ArrayType{FT}, zeros; Nij, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IFH{S, Nij}(device_zeros(FT, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "IFH")
data = IFH{S}(ArrayType{FT}, zeros; Ni, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IJF{S, Nij}(device_zeros(FT, Nij, Nij, Nf))
benchmarkfill!(bm, device, data, 3, "IJF")
data = IJF{S}(ArrayType{FT}, zeros; Nij)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = IF{S, Nij}(device_zeros(FT, Nij, Nf))
benchmarkfill!(bm, device, data, 3, "IF")
data = IF{S}(ArrayType{FT}, zeros; Ni)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VF{S, Nv}(device_zeros(FT, Nv, Nf))
benchmarkfill!(bm, device, data, 3, "VF")
data = VF{S}(ArrayType{FT}, zeros; Nv)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIJFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "VIJFH")
data = VIJFH{S}(ArrayType{FT}, zeros; Nv, Nij, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)
data = VIFH{S, Nv, Nij}(device_zeros(FT, Nv, Nij, Nf, Nh))
benchmarkfill!(bm, device, data, 3, "VIFH")
data = VIFH{S}(ArrayType{FT}, zeros; Nv, Ni, Nh)
benchmarkfill!(bm, device, data, 3)
@test all(parent(data) .== 3)

# data = IJKFVH{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = IH1JH2{S}(device_zeros(FT,Nij,Nij,Nk,Nf,Nh)); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); benchmarkfill!(bm, device, data, 3); @test all(parent(data) .== 3) # TODO: test
tabulate_benchmark(bm)
end
21 changes: 8 additions & 13 deletions test/DataLayouts/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ end
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
Nh = 10
src = IJFH{S, 4}(ArrayType(rand(4, 4, 3, Nh)))
dst = IJFH{S, 4}(ArrayType(zeros(4, 4, 3, Nh)))
src = IJFH{S}(ArrayType{Float64}, rand; Nij = 4, Nh)
dst = IJFH{S}(ArrayType{Float64}, zeros; Nij = 4, Nh)

test_copy!(dst, src)

Expand All @@ -47,21 +47,17 @@ end
Nh = 2
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
data_arr1 = ArrayType(ones(FT, 2, 2, 3, Nh))
data_arr2 = ArrayType(ones(FT, 2, 2, 1, Nh))
data1 = IJFH{S1, 2}(data_arr1)
data2 = IJFH{S2, 2}(data_arr2)
data1 = IJFH{S1}(ArrayType{FT}, ones; Nij = 2, Nh)
data2 = IJFH{S2}(ArrayType{FT}, ones; Nij = 2, Nh)

f1(a1, a2) = a1.a.re * a2 + a1.b
res = f1.(data1, data2)
@test res isa IJFH{Float64}
@test Array(parent(res)) == FT[2 for i in 1:2, j in 1:2, f in 1:1, h in 1:2]

Nv = 33
data_arr1 = ArrayType(ones(FT, Nv, 4, 4, 3, 2))
data_arr2 = ArrayType(ones(FT, Nv, 4, 4, 1, 2))
data1 = VIJFH{S1, Nv, 4}(data_arr1)
data2 = VIJFH{S2, Nv, 4}(data_arr2)
data1 = VIJFH{S1}(ArrayType{FT}, ones; Nv, Nij = 4, Nh = 2)
data2 = VIJFH{S2}(ArrayType{FT}, ones; Nv, Nij = 4, Nh = 2)

f2(a1, a2) = a1.a.re * a2 + a1.b
res = f2.(data1, data2)
Expand All @@ -77,14 +73,13 @@ end
Nh = 3
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)
array = similar(ArrayType{FT}, 2, 2, 2, Nh)
data = IJFH{S, 2}(array)
data = IJFH{S}(ArrayType{FT}; Nij = 2, Nh)
data .= Complex(1.0, 2.0)
@test Array(parent(data)) ==
FT[f == 1 ? 1 : 2 for i in 1:2, j in 1:2, f in 1:2, h in 1:3]

Nv = 33
data = VIJFH{S, Nv, 4}(ArrayType{FT}(undef, Nv, 4, 4, 2, Nh))
data = VIJFH{S}(ArrayType{FT}; Nv, Nij = 4, Nh)
data .= Complex(1.0, 2.0)
@test Array(parent(data)) == FT[
f == 1 ? 1 : 2 for v in 1:Nv, i in 1:4, j in 1:4, f in 1:2, h in 1:3
Expand Down
74 changes: 37 additions & 37 deletions test/DataLayouts/data0d.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@ using Revise; include(joinpath("test", "DataLayouts", "data0d.jl"))
using Test
using JET

using ClimaComms
using ClimaCore.DataLayouts
using StaticArrays
using ClimaCore.DataLayouts: get_struct, set_struct!

TestFloatTypes = (Float32, Float64)
device = ClimaComms.device()
ArrayType = ClimaComms.array_type(device)

@testset "DataF" begin
for FT in TestFloatTypes
S = Tuple{Complex{FT}, FT}
array = rand(FT, 3)

data = DataF{S}(array)
data = DataF{S}(ArrayType{FT}, rand)
array = parent(data)
@test getfield(data, :array) == array

# test tuple assignment
Expand All @@ -42,8 +45,7 @@ end

@testset "DataF boundscheck" begin
S = Tuple{Complex{Float64}, Float64}
array = zeros(Float64, 3)
data = DataF{S}(array)
data = DataF{S}(ArrayType{Float64}, zeros)
@test data[][2] == zero(Float64)
@test_throws MethodError data[1]
end
Expand All @@ -53,8 +55,7 @@ end
SA = (a = 1.0, b = 2.0)
SB = (c = 1.0, d = 2.0)

array = zeros(Float64, 2)
data = DataF{typeof(SA)}(array)
data = DataF{typeof(SA)}(ArrayType{Float64}, zeros)

ret = begin
data[] = SA
Expand All @@ -66,9 +67,8 @@ end

@testset "DataF broadcasting between 0D data objects and scalars" begin
for FT in TestFloatTypes
data1 = ones(FT, 2)
S = Complex{FT}
data1 = DataF{S}(data1)
data1 = DataF{S}(ArrayType{FT}, ones)
res = data1 .+ 1
@test res isa DataF
@test parent(res) == FT[2.0, 1.0]
Expand All @@ -91,12 +91,10 @@ end

@testset "DataF broadcasting between 0D data objects" begin
for FT in TestFloatTypes
data1 = ones(FT, 2)
data2 = ones(FT, 1)
S1 = Complex{FT}
S2 = FT
data1 = DataF{S1}(data1)
data2 = DataF{S2}(data2)
data1 = DataF{S1}(ArrayType{FT}, ones)
data2 = DataF{S2}(ArrayType{FT}, ones)
res = data1 .+ data2
@test res isa DataF{S1}
@test parent(res) == FT[2.0, 1.0]
Expand All @@ -108,8 +106,8 @@ end
FT = Float64
S = Complex{FT}
Nv = 3
data_f = DataF{S}(ones(FT, 2))
data_vf = VF{S, Nv}(ones(FT, Nv, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_vf = VF{S}(ArrayType{FT}, ones; Nv)
data_vf2 = data_f .+ data_vf
@test data_vf2 isa VF{S, Nv}
@test size(data_vf2) == (1, 1, 1, 3, 1)
Expand All @@ -118,8 +116,8 @@ end
@testset "broadcasting DataF + IF data object => IF" begin
FT = Float64
S = Complex{FT}
data_f = DataF{S}(ones(FT, 2))
data_if = IF{S, 2}(ones(FT, 2, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_if = IF{S}(ArrayType{FT}, ones; Ni = 2)
data_if2 = data_f .+ data_if
@test data_if2 isa IF{S}
@test size(data_if2) == (2, 1, 1, 1, 1)
Expand All @@ -129,8 +127,8 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
data_f = DataF{S}(ones(FT, 2))
data_ifh = IFH{S, 2}(ones(FT, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ifh = IFH{S}(ArrayType{FT}, ones; Ni = 2, Nh)
data_ifh2 = data_f .+ data_ifh
@test data_ifh2 isa IFH{S}
@test size(data_ifh2) == (2, 1, 1, 1, 3)
Expand All @@ -139,8 +137,8 @@ end
@testset "broadcasting DataF + IJF data object => IJF" begin
FT = Float64
S = Complex{FT}
data_f = DataF{S}(ones(FT, 2))
data_ijf = IJF{S, 2}(ones(FT, 2, 2, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ijf = IJF{S}(ArrayType{FT}, ones; Nij = 2)
data_ijf2 = data_f .+ data_ijf
@test data_ijf2 isa IJF{S}
@test size(data_ijf2) == (2, 2, 1, 1, 1)
Expand All @@ -150,8 +148,8 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
data_f = DataF{S}(ones(FT, 2))
data_ijfh = IJFH{S, 2}(ones(2, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_ijfh = IJFH{S}(ArrayType{FT}, ones; Nij = 2, Nh)
data_ijfh2 = data_f .+ data_ijfh
@test data_ijfh2 isa IJFH{S}
@test size(data_ijfh2) == (2, 2, 1, 1, Nh)
Expand All @@ -161,9 +159,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 10
data_f = DataF{S}(ones(FT, 2))
data_f = DataF{S}(ArrayType{FT}, ones)
Nv = 10
data_vifh = VIFH{S, Nv, 4}(ones(FT, Nv, 4, 2, Nh))
data_vifh = VIFH{S}(ArrayType{FT}, ones; Nv, Ni = 4, Nh)
data_vifh2 = data_f .+ data_vifh
@test data_vifh2 isa VIFH{S, Nv}
@test size(data_vifh2) == (4, 1, 1, Nv, Nh)
Expand All @@ -174,8 +172,8 @@ end
S = Complex{FT}
Nv = 2
Nh = 2
data_f = DataF{S}(ones(FT, 2))
data_vijfh = VIJFH{S, Nv, 2}(ones(FT, Nv, 2, 2, 2, Nh))
data_f = DataF{S}(ArrayType{FT}, ones)
data_vijfh = VIJFH{S}(ArrayType{FT}, ones; Nv, Nij = 2, Nh)
data_vijfh2 = data_f .+ data_vijfh
@test data_vijfh2 isa VIJFH{S, Nv}
@test size(data_vijfh2) == (2, 2, 1, Nv, Nh)
Expand All @@ -184,8 +182,9 @@ end
@testset "column IF => DataF" begin
FT = Float64
S = Complex{FT}
array = FT[1 2; 3 4]
data_if = IF{S, 2}(array)
data_if = IF{S}(ArrayType{FT}; Ni = 2)
array = parent(data_if)
array .= FT[1 2; 3 4]
if_column = column(data_if, 2)
@test if_column isa DataF
@test if_column[] == 3.0 + 4.0im
Expand All @@ -196,9 +195,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
array = ones(FT, 2, 2, Nh)
data_ifh = IFH{S}(ArrayType{FT}; Ni = 2, Nh)
array = parent(data_ifh)
array[1, :, 1] .= FT[3, 4]
data_ifh = IFH{S, 2}(array)
ifh_column = column(data_ifh, 1, 1)
@test ifh_column isa DataF
@test ifh_column[] == 3.0 + 4.0im
Expand All @@ -209,9 +208,9 @@ end
@testset "column IJF => DataF" begin
FT = Float64
S = Complex{FT}
array = ones(FT, 2, 2, 2)
data_ijf = IJF{S}(ArrayType{FT}; Nij = 2)
array = parent(data_ijf)
array[1, 1, :] .= FT[3, 4]
data_ijf = IJF{S, 2}(array)
ijf_column = column(data_ijf, 1, 1)
@test ijf_column isa DataF
@test ijf_column[] == 3.0 + 4.0im
Expand All @@ -223,9 +222,9 @@ end
FT = Float64
S = Complex{FT}
Nh = 3
array = ones(2, 2, 2, 3)
data_ijfh = IJFH{S}(ArrayType{FT}; Nij = 2, Nh)
array = parent(data_ijfh)
array[1, 1, :, 2] .= FT[3, 4]
data_ijfh = IJFH{S, 2}(array)
ijfh_column = column(data_ijfh, 1, 1, 2)
@test ijfh_column isa DataF
@test ijfh_column[] == 3.0 + 4.0im
Expand All @@ -237,9 +236,10 @@ end
@testset "level VF => DataF" begin
FT = Float64
S = Complex{FT}
array = FT[1 2; 3 4; 5 6]
Nv = size(array, 1)
data_vf = VF{S, Nv}(array)
Nv = 3
data_vf = VF{S}(ArrayType{FT}; Nv)
array = parent(data_vf)
array .= FT[1 2; 3 4; 5 6]
vf_level = level(data_vf, 2)
@test vf_level isa DataF
@test vf_level[] == 3.0 + 4.0im
Expand Down
Loading
Loading