Skip to content

Commit

Permalink
Merge pull request #1521 from CliMA/sk/cuda_mapreduce_mvar
Browse files Browse the repository at this point in the history
Fix CUDA mapreduce bug, affecting fields with multiple variables.
  • Loading branch information
sriharshakandala authored Nov 10, 2023
2 parents 4f35601 + 7bd78dc commit c76e128
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 24 deletions.
43 changes: 19 additions & 24 deletions src/Fields/mapreduce_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,42 +150,33 @@ function mapreduce_cuda_kernel!(
tidx = threadIdx().x
bidx = blockIdx().x
fidx = blockIdx().y
dataview = _dataview(pdata, fidx)
effective_blksize = blksize * (n_ops_on_load + 1)
gidx = _get_gidx(tidx, bidx, fidx, effective_blksize, nblk)
gidx = _get_gidx(tidx, bidx, effective_blksize)
reduction = CUDA.CuStaticSharedArray(T, shmemsize)
reduction[tidx] = 0
(Nv, Nij, Nf, Nh) = _get_dims(pdata)
(Nv, Nij, Nf, Nh) = _get_dims(dataview)
nitems = Nv * Nij * Nij * Nf * Nh

# load shmem
if gidx nitems
if weighting
reduction[tidx] = f(pdata[gidx]) * pwt[gidx]
reduction[tidx] = f(dataview[gidx]) * pwt[gidx]
for n_ops in 1:n_ops_on_load
gidx2 = _get_gidx(
tidx + blksize * n_ops,
bidx,
fidx,
effective_blksize,
nblk,
)
gidx2 =
_get_gidx(tidx + blksize * n_ops, bidx, effective_blksize)
if gidx2 nitems
reduction[tidx] =
op(reduction[tidx], f(pdata[gidx2]) * pwt[gidx2])
op(reduction[tidx], f(dataview[gidx2]) * pwt[gidx2])
end
end
else
reduction[tidx] = f(pdata[gidx])
reduction[tidx] = f(dataview[gidx])
for n_ops in 1:n_ops_on_load
gidx2 = _get_gidx(
tidx + blksize * n_ops,
bidx,
fidx,
effective_blksize,
nblk,
)
gidx2 =
_get_gidx(tidx + blksize * n_ops, bidx, effective_blksize)
if gidx2 nitems
reduction[tidx] = op(reduction[tidx], f(pdata[gidx2]))
reduction[tidx] = op(reduction[tidx], f(dataview[gidx2]))
end
end
end
Expand All @@ -197,28 +188,32 @@ function mapreduce_cuda_kernel!(
return nothing
end

@inline function _get_gidx(tidx, bidx, fidx, effective_blksize, nblk)
return tidx +
(bidx - 1) * effective_blksize +
(fidx - 1) * effective_blksize * nblk
@inline function _get_gidx(tidx, bidx, effective_blksize)
return tidx + (bidx - 1) * effective_blksize
end
# for VF DataLayout
@inline function _get_dims(pdata::AbstractArray{FT, 2}) where {FT}
(Nv, Nf) = size(pdata)
return (Nv, 1, Nf, 1)
end
@inline _dataview(pdata::AbstractArray{FT, 2}, fidx) where {FT} =
view(pdata, :, fidx:fidx)

# for IJFH DataLayout
@inline function _get_dims(pdata::AbstractArray{FT, 4}) where {FT}
(Nij, _, Nf, Nh) = size(pdata)
return (1, Nij, Nf, Nh)
end
@inline _dataview(pdata::AbstractArray{FT, 4}, fidx) where {FT} =
view(pdata, :, :, fidx:fidx, :)

# for VIJFH DataLayout
@inline function _get_dims(pdata::AbstractArray{FT, 5}) where {FT}
(Nv, Nij, _, Nf, Nh) = size(pdata)
return (Nv, Nij, Nf, Nh)
end
@inline _dataview(pdata::AbstractArray{FT, 5}, fidx) where {FT} =
view(pdata, :, :, :, fidx:fidx, :)

@inline function _get_idxs(Nv, Nij, Nf, Nh, fidx, gidx)
hidx = cld(gidx, Nv * Nij * Nij * Nf)
Expand Down
38 changes: 38 additions & 0 deletions test/Fields/reduction_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,41 @@ end

@test LinearAlgebra.norm(Yc, Inf) LinearAlgebra.norm(Yc_cpu, Inf)
end

@testset "test cuda reduction on a field with multiple variables" begin
comms_ctx = ClimaComms.SingletonCommsContext(ClimaComms.device())
comms_ctx_cpu =
ClimaComms.SingletonCommsContext(ClimaComms.CPUSingleThreaded())
FT = Float64
println("running multi-variable reduction test on $(comms_ctx.device)")
domain = Domains.RectangleDomain(
Domains.IntervalDomain(
Geometry.XPoint{FT}(0.0),
Geometry.XPoint{FT}(1.0);
periodic = false,
boundary_names = (:west, :east),
),
Domains.IntervalDomain(
Geometry.YPoint{FT}(0.0),
Geometry.YPoint{FT}(1.0);
periodic = false,
boundary_names = (:south, :north),
),
)
mesh = Meshes.RectilinearMesh(domain, 3, 3)
topology = Topologies.Topology2D(comms_ctx, mesh)
topology_cpu = Topologies.Topology2D(comms_ctx_cpu, mesh)
quad = Spaces.Quadratures.GLL{5}()
space = Spaces.SpectralElementSpace2D(topology, quad)
space_cpu = Spaces.SpectralElementSpace2D(topology_cpu, quad)
coords = Fields.coordinate_field(space)
coords_cpu = Fields.coordinate_field(space_cpu)

q₀(coords, x_scale, y_scale) =
(x = x_scale * coords.x, y = y_scale * coords.y)

q = @. q₀(coords, 1.2, 1.5)
q_cpu = @. q₀(coords_cpu, 1.2, 1.5)

@test [sum(q)...] [sum(q_cpu)...]
end

0 comments on commit c76e128

Please sign in to comment.