diff --git a/src/Fields/mapreduce_cuda.jl b/src/Fields/mapreduce_cuda.jl
index 03ab059be0..035648dbb8 100644
--- a/src/Fields/mapreduce_cuda.jl
+++ b/src/Fields/mapreduce_cuda.jl
@@ -150,42 +150,33 @@ function mapreduce_cuda_kernel!(
     tidx = threadIdx().x
     bidx = blockIdx().x
     fidx = blockIdx().y
+    dataview = _dataview(pdata, fidx)
     effective_blksize = blksize * (n_ops_on_load + 1)
-    gidx = _get_gidx(tidx, bidx, fidx, effective_blksize, nblk)
+    gidx = _get_gidx(tidx, bidx, effective_blksize)
     reduction = CUDA.CuStaticSharedArray(T, shmemsize)
     reduction[tidx] = 0
-    (Nv, Nij, Nf, Nh) = _get_dims(pdata)
+    (Nv, Nij, Nf, Nh) = _get_dims(dataview)
     nitems = Nv * Nij * Nij * Nf * Nh
 
     # load shmem
     if gidx ≤ nitems
         if weighting
-            reduction[tidx] = f(pdata[gidx]) * pwt[gidx]
+            reduction[tidx] = f(dataview[gidx]) * pwt[gidx]
             for n_ops in 1:n_ops_on_load
-                gidx2 = _get_gidx(
-                    tidx + blksize * n_ops,
-                    bidx,
-                    fidx,
-                    effective_blksize,
-                    nblk,
-                )
+                gidx2 =
+                    _get_gidx(tidx + blksize * n_ops, bidx, effective_blksize)
                 if gidx2 ≤ nitems
                     reduction[tidx] =
-                        op(reduction[tidx], f(pdata[gidx2]) * pwt[gidx2])
+                        op(reduction[tidx], f(dataview[gidx2]) * pwt[gidx2])
                 end
             end
         else
-            reduction[tidx] = f(pdata[gidx])
+            reduction[tidx] = f(dataview[gidx])
             for n_ops in 1:n_ops_on_load
-                gidx2 = _get_gidx(
-                    tidx + blksize * n_ops,
-                    bidx,
-                    fidx,
-                    effective_blksize,
-                    nblk,
-                )
+                gidx2 =
+                    _get_gidx(tidx + blksize * n_ops, bidx, effective_blksize)
                 if gidx2 ≤ nitems
-                    reduction[tidx] = op(reduction[tidx], f(pdata[gidx2]))
+                    reduction[tidx] = op(reduction[tidx], f(dataview[gidx2]))
                 end
             end
         end
@@ -197,28 +188,32 @@ function mapreduce_cuda_kernel!(
     return nothing
 end
 
-@inline function _get_gidx(tidx, bidx, fidx, effective_blksize, nblk)
-    return tidx +
-           (bidx - 1) * effective_blksize +
-           (fidx - 1) * effective_blksize * nblk
+@inline function _get_gidx(tidx, bidx, effective_blksize)
+    return tidx + (bidx - 1) * effective_blksize
 end
 # for VF DataLayout
 @inline function _get_dims(pdata::AbstractArray{FT, 2}) where {FT}
     (Nv, Nf) = size(pdata)
     return (Nv, 1, Nf, 1)
 end
+@inline _dataview(pdata::AbstractArray{FT, 2}, fidx) where {FT} =
+    view(pdata, :, fidx:fidx)
 
 # for IJFH DataLayout
 @inline function _get_dims(pdata::AbstractArray{FT, 4}) where {FT}
     (Nij, _, Nf, Nh) = size(pdata)
     return (1, Nij, Nf, Nh)
 end
+@inline _dataview(pdata::AbstractArray{FT, 4}, fidx) where {FT} =
+    view(pdata, :, :, fidx:fidx, :)
 
 # for VIJFH DataLayout
 @inline function _get_dims(pdata::AbstractArray{FT, 5}) where {FT}
     (Nv, Nij, _, Nf, Nh) = size(pdata)
     return (Nv, Nij, Nf, Nh)
 end
+@inline _dataview(pdata::AbstractArray{FT, 5}, fidx) where {FT} =
+    view(pdata, :, :, :, fidx:fidx, :)
 
 @inline function _get_idxs(Nv, Nij, Nf, Nh, fidx, gidx)
     hidx = cld(gidx, Nv * Nij * Nij * Nf)
diff --git a/test/Fields/reduction_cuda.jl b/test/Fields/reduction_cuda.jl
index 0c34076df4..cc26a79911 100644
--- a/test/Fields/reduction_cuda.jl
+++ b/test/Fields/reduction_cuda.jl
@@ -252,3 +252,41 @@ end
 
     @test LinearAlgebra.norm(Yc, Inf) ≈ LinearAlgebra.norm(Yc_cpu, Inf)
 end
+
+@testset "test cuda reduction on a field with multiple variables" begin
+    comms_ctx = ClimaComms.SingletonCommsContext(ClimaComms.device())
+    comms_ctx_cpu =
+        ClimaComms.SingletonCommsContext(ClimaComms.CPUSingleThreaded())
+    FT = Float64
+    println("running multi-variable reduction test on $(comms_ctx.device)")
+    domain = Domains.RectangleDomain(
+        Domains.IntervalDomain(
+            Geometry.XPoint{FT}(0.0),
+            Geometry.XPoint{FT}(1.0);
+            periodic = false,
+            boundary_names = (:west, :east),
+        ),
+        Domains.IntervalDomain(
+            Geometry.YPoint{FT}(0.0),
+            Geometry.YPoint{FT}(1.0);
+            periodic = false,
+            boundary_names = (:south, :north),
+        ),
+    )
+    mesh = Meshes.RectilinearMesh(domain, 3, 3)
+    topology = Topologies.Topology2D(comms_ctx, mesh)
+    topology_cpu = Topologies.Topology2D(comms_ctx_cpu, mesh)
+    quad = Spaces.Quadratures.GLL{5}()
+    space = Spaces.SpectralElementSpace2D(topology, quad)
+    space_cpu = Spaces.SpectralElementSpace2D(topology_cpu, quad)
+    coords = Fields.coordinate_field(space)
+    coords_cpu = Fields.coordinate_field(space_cpu)
+
+    q₀(coords, x_scale, y_scale) =
+        (x = x_scale * coords.x, y = y_scale * coords.y)
+
+    q = @. q₀(coords, 1.2, 1.5)
+    q_cpu = @. q₀(coords_cpu, 1.2, 1.5)
+
+    @test [sum(q)...] ≈ [sum(q_cpu)...]
+end