From b827c91f5f5255b9da94cd6611282ab4d9bb8218 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Tue, 17 Dec 2024 16:53:56 -0800 Subject: [PATCH] Fix correctness in cuda_mapreduce `cuda_mapreduce` was not working correctly with certain spaces. Why was this happening? I added a comment to describe the algorithm in the commit. In a nutshell, the algorithm was not taking into account the fact that the final block is not completely filled with points to process. Therefore, the reduction included some elements that did not contain real points (but the value 0). --- ext/cuda/data_layouts_mapreduce.jl | 36 +++++++++++++++++++++++++++++- test/DataLayouts/unit_mapreduce.jl | 30 +++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/ext/cuda/data_layouts_mapreduce.jl b/ext/cuda/data_layouts_mapreduce.jl index 272216f7c7..4272e972a2 100644 --- a/ext/cuda/data_layouts_mapreduce.jl +++ b/ext/cuda/data_layouts_mapreduce.jl @@ -31,6 +31,34 @@ function mapreduce_cuda( weighted_jacobian = OnesArray(parent(data)), opargs..., ) + # This function implements the following parallel reduction algorithm: + # + # Blocks processes multiple data points at the same time (n_ops_on_load) + # each and we perform a block-wise reduction, with each block writing to an + # array of (block-)shared memory. This array has the same size as the block, + # ie, it is as long as many threads are available. Processing multiple + # points means that we apply the reduction to the point with index + # reduction[thread_index] = f(thread_index, thread_index + OFFSET), with + # various OFFSETS that depend on `n_ops_on_load` and block size. + # + # For the purpose of indexing, this is equivalent to having larger blocks + # with size effective_blksize = blksize * (n_ops_on_load + 1). + # + # + # After this operation, we have reduced all the data by a factor of + # 1/n_ops_on_load and have results in various arrays `reduction` (one per + # block) + # + # Once we have all the blocks reduced, we perform a tree reduction within + # the block and "move" the reduced value to the first element of the array. + # In this, one of the things to watch out for is that the last block might + # not necessarily have all threads doing work, so we have to be careful to + # not include data in `reduction` that did not have corresponding work. + # Threads of index 1 will write that array into an output array. + # + # The output array has size nblocks, so we do another round of reduction, + # but this time we put each Field in a different block. + S = eltype(data) pdata = parent(data) T = eltype(pdata) @@ -112,7 +140,13 @@ function mapreduce_cuda_kernel!( end end sync_threads() - _cuda_intrablock_reduce!(op, reduction, tidx, blksize) + + # The last block might not have enough threads to fill `reduction`, so some + # of its elements might still have the value at initialization. + blksize_for_reduction = + min(blksize, nitems - effective_blksize * (bidx - 1)) + + _cuda_intrablock_reduce!(op, reduction, tidx, blksize_for_reduction) tidx == 1 && (reduce_cuda[bidx, fidx] = reduction[1]) return nothing diff --git a/test/DataLayouts/unit_mapreduce.jl b/test/DataLayouts/unit_mapreduce.jl index a0bc4d33f6..2d5cc368bd 100644 --- a/test/DataLayouts/unit_mapreduce.jl +++ b/test/DataLayouts/unit_mapreduce.jl @@ -162,3 +162,33 @@ end # data = DataLayouts.IJKFVH{S}(ArrayType{FT}, zeros; Nij,Nk,Nv,Nh); test_mapreduce_2!(context, data_view(data)) # TODO: test # data = DataLayouts.IH1JH2{S}(ArrayType{FT}, zeros; Nij); test_mapreduce_2!(context, data_view(data)) # TODO: test end + +@testset "mapreduce with space with some non-round blocks" begin + # https://github.com/CliMA/ClimaCore.jl/issues/2097 + space = ClimaCore.CommonSpaces.RectangleXYSpace(; + x_min = 0, + x_max = 1, + y_min = 0, + y_max = 1, + periodic_x = false, + periodic_y = false, + n_quad_points = 4, + x_elem = 129, + y_elem = 129, + ) + @test minimum(ones(space)) == 1 + + # Less than 256 threads + space = ClimaCore.CommonSpaces.RectangleXYSpace(; + x_min = 0, + x_max = 1, + y_min = 0, + y_max = 1, + periodic_x = false, + periodic_y = false, + n_quad_points = 2, + x_elem = 1, + y_elem = 1, + ) + @test minimum(ones(space)) == 1 +end