Skip to content

Commit

Permalink
Use UniversalSize struct in more kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 24, 2024
1 parent 4736aea commit 82151d9
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 28 deletions.
4 changes: 4 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ steps:
key: data_opt_similar
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/opt_similar.jl"

- label: "Unit: opt_universal_size"
key: opt_universal_size
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/opt_universal_size.jl"

- label: "Unit: data_ndims"
key: unit_data_ndims
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/DataLayouts/unit_ndims.jl"
Expand Down
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import ClimaCore.Utilities: half
import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
⊠, ⊞, ⊟, radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

include(joinpath("cuda", "cuda_utils.jl"))
include(joinpath("cuda", "data_layouts.jl"))
Expand Down
11 changes: 6 additions & 5 deletions ext/cuda/data_layouts_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ function Base.copyto!(
end

import ClimaCore.DataLayouts: isascalar
function knl_copyto_flat!(dest::AbstractData, bc)
function knl_copyto_flat!(dest::AbstractData, bc, us)
@inbounds begin
n = size(dest)
tidx = thread_index()
if valid_range(tidx, prod(n))
if tidx ≤ get_N(us)
n = size(dest)
I = kernel_indexes(tidx, n)
dest[I] = bc[I]
end
Expand All @@ -98,14 +98,15 @@ end

function cuda_copyto!(dest::AbstractData, bc)
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_copyto_flat!, (dest, bc), dest; auto = true)
auto_launch!(knl_copyto_flat!, (dest, bc, us), dest; auto = true)
end
return dest
end

# TODO: can we use CUDA's launch configuration for all data layouts?
# Currently, it seems to have a slight performance degredation.
# Currently, it seems to have a slight performance degradation.
#! format: off
# Base.copyto!(dest::IJFH{S, Nij}, bc::DataLayouts.BroadcastedUnionIJFH{S, Nij, Nh}, ::ToCUDA) where {S, Nij, Nh} = cuda_copyto!(dest, bc)
Base.copyto!(dest::IFH{S, Ni, Nh}, bc::DataLayouts.BroadcastedUnionIFH{S, Ni, Nh}, ::ToCUDA) where {S, Ni, Nh} = cuda_copyto!(dest, bc)
Expand Down
9 changes: 5 additions & 4 deletions ext/cuda/data_layouts_fill.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
function knl_fill_flat!(dest::AbstractData, val)
function knl_fill_flat!(dest::AbstractData, val, us)
@inbounds begin
tidx = thread_index()
n = size(dest)
if valid_range(tidx, prod(n))
if tidx ≤ get_N(us)
n = size(dest)
I = kernel_indexes(tidx, n)
@inbounds dest[I] = val
end
Expand All @@ -12,8 +12,9 @@ end

function cuda_fill!(dest::AbstractData, val)
(_, _, Nv, Nh) = DataLayouts.universal_size(dest)
us = DataLayouts.UniversalSize(dest)
if Nv > 0 && Nh > 0
auto_launch!(knl_fill_flat!, (dest, val), dest; auto = true)
auto_launch!(knl_fill_flat!, (dest, val, us), dest; auto = true)
end
return dest
end
Expand Down
28 changes: 9 additions & 19 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,11 @@ function Base.copyto!(
(li, lw, rw, ri) = bounds = Operators.window_bounds(space, bc)
Nv = ri - li + 1
max_threads = 256
us = DataLayouts.UniversalSize(Fields.field_values(out))
nitems = Nv * Nq * Nq * Nh # # of independent items
(nthreads, nblocks) = _configure_threadblock(max_threads, nitems)
args = (
strip_space(out, space),
strip_space(bc, space),
axes(out),
bounds,
Val(Nv),
Val(Nq),
Nh,
)
args =
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)
auto_launch!(
copyto_stencil_kernel!,
args,
Expand All @@ -49,20 +43,16 @@ function Base.copyto!(
)
return out
end
import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

function copyto_stencil_kernel!(
out,
bc,
space,
bds,
::Val{Nv},
::Val{Nq},
Nh,
) where {Nv, Nq}
function copyto_stencil_kernel!(out, bc, space, bds, us)
@inbounds begin
gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x
if gid ≤ Nv * Nq * Nq * Nh
if gid ≤ get_N(us)
(li, lw, rw, ri) = bds
Nv = get_Nv(us)
Nq = get_Nij(us)
Nh = get_Nh(us)
(v, i, j, h) = cart_ind((Nv, Nq, Nq, Nh), gid).I
hidx = (i, j, h)
idx = v - 1 + li
Expand Down
35 changes: 35 additions & 0 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,41 @@ corresponding to `UniversalSize`.
"""
@inline universal_size(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
(Ni, Nj, Nv, Nh)

"""
    get_N(::AbstractData)
    get_N(::UniversalSize)

Statically returns the product of the universal size, `prod((Ni, Nj, Nv, Nh))`.
The `AbstractData` method forwards to the `UniversalSize` method.
"""
@inline function get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh}
    return prod((Ni, Nj, Nv, Nh))
end

@inline function get_N(data::AbstractData)
    us = UniversalSize(data)
    return get_N(us)
end

"""
    get_Nv(::UniversalSize)

Statically returns `Nv`, read from the third type parameter of `UniversalSize`.
"""
function get_Nv(::UniversalSize{Ni, Nj, Nv}) where {Ni, Nj, Nv}
    return Nv
end

"""
    get_Nij(::UniversalSize)

Statically returns `Nij`, read from the first type parameter of `UniversalSize`.
"""
function get_Nij(::UniversalSize{Nij}) where {Nij}
    return Nij
end

"""
    get_Nh(::UniversalSize)
    get_Nh(::AbstractData)

Statically returns `Nh`, read from the fourth type parameter of
`UniversalSize`. The `AbstractData` method forwards to the `UniversalSize`
method via `UniversalSize(data)`.
"""
get_Nh(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} = Nh

# `Nh` is not bound inside this method (there is no `where` clause), so the
# value must be recovered from `UniversalSize(data)`; returning a bare `Nh`
# would raise `UndefVarError`. This mirrors how `get_N(data)` is defined.
get_Nh(data::AbstractData) = get_Nh(UniversalSize(data))

# Forward to the `UniversalSize` method using the `UniversalSize` built from `data`.
@inline function universal_size(data::AbstractData)
    return universal_size(UniversalSize(data))
end

function Base.show(io::IO, data::AbstractData)
Expand Down
68 changes: 68 additions & 0 deletions test/DataLayouts/opt_universal_size.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#=
julia --project
using Revise; include(joinpath("test", "DataLayouts", "opt_universal_size.jl"))
=#
using Test
using ClimaCore.DataLayouts
using ClimaCore: DataLayouts, Geometry
import ClimaComms
using StaticArrays: SMatrix
ClimaComms.@import_required_backends
using JET
using InteractiveUtils: @code_typed

# Check that the size queries on `data` (and on its `UniversalSize`) compile
# down to a bare constant return, and that the constant equals the value the
# call actually produces.
function test_universal_size(data)
    us = DataLayouts.UniversalSize(data)

    # `typed` is the pair returned by `@code_typed`. The first statement of the
    # typed code must already be the return node, and the returned constant
    # must match `f(arg)`.
    function expect_static(typed, f, arg)
        @test first(typed).code[1] isa Core.ReturnNode
        @test first(typed).code[end].val == f(arg)
    end

    # Queries on the `UniversalSize` itself.
    expect_static((@code_typed DataLayouts.get_N(us)), DataLayouts.get_N, us)
    expect_static((@code_typed DataLayouts.get_Nv(us)), DataLayouts.get_Nv, us)
    expect_static((@code_typed DataLayouts.get_Nij(us)), DataLayouts.get_Nij, us)
    expect_static((@code_typed DataLayouts.get_Nh(us)), DataLayouts.get_Nh, us)

    # Queries on the data layout.
    expect_static((@code_typed size(data)), size, data)
    expect_static((@code_typed DataLayouts.get_N(data)), DataLayouts.get_N, data)

    # Demo of failed constant propagation: the first statement is an
    # expression rather than a return node.
    typed_prod = @code_typed prod(size(data))
    @test first(typed_prod).code[1] isa Expr
end

@testset "UniversalSize" begin
    device = ClimaComms.device()
    # Zero-filled array of the requested shape, converted to the device array type.
    device_zeros(args...) = ClimaComms.array_type(device)(zeros(args...))
    FT = Float64
    S = FT
    Nf, Nv, Nij, Nh, Nk = 1, 4, 3, 5, 6

    # One constructor per data layout under test. Each layout is built and
    # checked before the next one is constructed.
    layout_builders = (
        () -> DataF{S}(device_zeros(FT, Nf)),
        () -> IJFH{S, Nij, Nh}(device_zeros(FT, Nij, Nij, Nf, Nh)),
        () -> IFH{S, Nij, Nh}(device_zeros(FT, Nij, Nf, Nh)),
        () -> IJF{S, Nij}(device_zeros(FT, Nij, Nij, Nf)),
        () -> IF{S, Nij}(device_zeros(FT, Nij, Nf)),
        () -> VF{S, Nv}(device_zeros(FT, Nv, Nf)),
        () -> VIJFH{S, Nv, Nij, Nh}(device_zeros(FT, Nv, Nij, Nij, Nf, Nh)),
        () -> VIFH{S, Nv, Nij, Nh}(device_zeros(FT, Nv, Nij, Nf, Nh)),
    )
    for build in layout_builders
        test_universal_size(build())
    end
    # data = DataLayouts.IJKFVH{S, Nij, Nk, Nv, Nh}(device_zeros(FT,Nij,Nij,Nk,Nf,Nv,Nh)); test_universal_size(data) # TODO: test
    # data = DataLayouts.IH1JH2{S, Nij}(device_zeros(FT,2*Nij,3*Nij)); test_universal_size(data) # TODO: test
end

0 comments on commit 82151d9

Please sign in to comment.