Skip to content

Commit

Permalink
Merge pull request #2106 from CliMA/gb/fix_cuda_reductions
Browse files Browse the repository at this point in the history
Fix correctness in cuda_mapreduce
  • Loading branch information
Sbozzolo authored Dec 21, 2024
2 parents b4a04d8 + 8cdf3f3 commit 6539b89
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ main

### ![][badge-🐛bugfix] Bug fixes

- Fixed writing/reading purely vertical spaces
- Fixed writing/reading purely vertical spaces. PR [2102](https://github.com/CliMA/ClimaCore.jl/pull/2102)
- Fixed correctness bug in reductions on GPUs. PR [2106](https://github.com/CliMA/ClimaCore.jl/pull/2106)

v0.14.20
--------
Expand Down
37 changes: 36 additions & 1 deletion ext/cuda/data_layouts_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,35 @@ function mapreduce_cuda(
weighted_jacobian = OnesArray(parent(data)),
opargs...,
)
# This function implements the following parallel reduction algorithm:
#
# Each thread in each blocks processes multiple data points at the same time
# (n_ops_on_load) each and we perform a block-wise reduction, with each
# block writing to an array of (block-)shared memory. This array has the
# same size as the block, ie, it is as long as many threads are available.
# Processing multiple points means that we apply the reduction to the point
# with index reduction[thread_index] = f(thread_index, thread_index +
# OFFSET), with various OFFSETS that depend on `n_ops_on_load` and block
# size.
#
# For the purpose of indexing, this is equivalent to having larger blocks
# with size effective_blksize = blksize * (n_ops_on_load + 1).
#
#
# After this operation, we have reduced all the data by a factor of
# 1/n_ops_on_load and have results in various arrays `reduction` (one per
# block)
#
# Once we have all the blocks reduced, we perform a tree reduction within
# the block and "move" the reduced value to the first element of the array.
# In this, one of the things to watch out for is that the last block might
# not necessarily have all threads doing work, so we have to be careful to
# not include data in `reduction` that did not have corresponding work.
# Threads of index 1 will write that array into an output array.
#
# The output array has size nblocks, so we do another round of reduction,
# but this time we put each Field in a different block.

S = eltype(data)
pdata = parent(data)
T = eltype(pdata)
Expand Down Expand Up @@ -112,7 +141,13 @@ function mapreduce_cuda_kernel!(
end
end
sync_threads()
_cuda_intrablock_reduce!(op, reduction, tidx, blksize)

# The last block might not have enough threads to fill `reduction`, so some
# of its elements might still have the value at initialization.
blksize_for_reduction =
min(blksize, nitems - effective_blksize * (bidx - 1))

_cuda_intrablock_reduce!(op, reduction, tidx, blksize_for_reduction)

tidx == 1 && (reduce_cuda[bidx, fidx] = reduction[1])
return nothing
Expand Down
32 changes: 32 additions & 0 deletions test/DataLayouts/unit_mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,35 @@ end
# data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); test_mapreduce_2!(context, data_view(data)) # TODO: test
# data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_mapreduce_2!(context, data_view(data)) # TODO: test
end

@testset "mapreduce with space with some non-round blocks" begin
# https://github.com/CliMA/ClimaCore.jl/issues/2097
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
periodic_x = false,
periodic_y = false,
n_quad_points = 4,
x_elem = 129,
y_elem = 129,
)
@test minimum(ones(space)) == 1

if ClimaComms.context isa ClimaComms.SingletonCommsContext
# Less than 256 threads
space = ClimaCore.CommonSpaces.RectangleXYSpace(;
x_min = 0,
x_max = 1,
y_min = 0,
y_max = 1,
periodic_x = false,
periodic_y = false,
n_quad_points = 2,
x_elem = 2,
y_elem = 2,
)
@test minimum(ones(space)) == 1
end
end

0 comments on commit 6539b89

Please sign in to comment.