diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 72178f8519..3cb61f6225 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -78,6 +78,11 @@ # process-local interpolate points in the correct shape with respect to the global # interpolation (and where to collect results) # +# To process multiple Fields at the same time, some of the scratch spaces gain an extra +# dimension (buffer_length). With this extra dimension, we can batch the work and process up +# to buffer_length fields at the same time. This reduces the number of kernel launches and +# MPI calls. +# # For GPU runs, we store all the vectors as CuArrays on the GPU. """ @@ -124,6 +129,7 @@ struct Remapper{ T8 <: AbstractArray, T9 <: AbstractArray, T10 <: AbstractArray, + T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}}, } comms_ctx::CC @@ -170,25 +176,32 @@ struct Remapper{ # Scratch space where we save the process-local interpolated values. We keep overwriting # this to avoid extra allocations. This is a linear array with the same length as - # local_horiz_indices. + # local_horiz_indices with an extra dimension of length buffer_length added. _local_interpolated_values::T8 # Scratch space where we save the process-local field value. We keep overwriting this to # avoid extra allocations. Ideally, we wouldn't need this and we would use views for - # everything. This has dimensions (Nq) or (Nq, Nq) depending if the horizontal space is - # 1D or 2D. + # everything. This has dimensions (buffer_length, Nq) or (buffer_length, Nq, Nq) + # depending if the horizontal space is 1D or 2D. _field_values::T9 # Storage area where the interpolated values are saved. This is meaningful only for the - # root process and gets filled by a interpolate call. This has dimensions (H, V), where - # H is the size of target_hcoords and V of target_zcoords. In other words, this is the - # expected output array. + # root process and gets filled by a interpolate call. This has dimensions + # (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords. + # In other words, this is the expected output array. _interpolated_values::T10 + + # Maximum number of Fields that can be interpolated at any given time + buffer_length::Int + + # A tuple of Colons (1, 2, or 3), used to more easily get views into arrays with unknown + # dimension (1-3D) + colons::T11 end """ - Remapper(space, target_hcoords, target_zcoords) - Remapper(space, target_hcoords) + Remapper(space, target_hcoords, target_zcoords; buffer_length = 1) + Remapper(space, target_hcoords; buffer_length = 1) Return a `Remapper` responsible for interpolating any `Field` defined on the given `space` to the Cartesian product of `target_hcoords` with `target_zcoords`. @@ -199,11 +212,21 @@ The `Remapper` is designed to not be tied to any particular `Field`. You can use `Remapper` for any `Field` as long as they are all defined on the same `topology`. `Remapper` is the main argument to the `interpolate` function. + +Keyword arguments +================= + +`buffer_length` is size of the internal buffer in the Remapper to store intermediate values +for interpolation. Effectively, this controls how many fields can be remapped simultaneously +in `interpolate`. When more fields than `buffer_length` are passed, the remapper will batch +the work in sizes of `buffer_length`. 
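The batching described above can be pictured with a short usage sketch (the space, coordinate, and field names below are placeholders, not part of this patch):

```julia
# Hypothetical sketch: `space` is an extruded 3D space, `hcoords`/`zcoords` are
# the target coordinates, and `fields` is a vector of 5 Fields on `space`.
remapper = Remapper(space, hcoords, zcoords; buffer_length = 2)

# With buffer_length = 2, the 5 fields are processed in batches 1:2, 3:4, 5:5,
# i.e. three rounds of kernel launches and MPI reductions instead of five.
output = interpolate(remapper, fields)
# On the root process, size(output, 1) == 5 and the trailing dimensions follow
# the target horizontal and vertical coordinates.
```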
+ """ function Remapper( space::Spaces.AbstractSpace, target_hcoords::AbstractArray, - target_zcoords::Union{AbstractArray, Nothing}, + target_zcoords::Union{AbstractArray, Nothing}; + buffer_length::Int = 1, ) comms_ctx = ClimaComms.context(space) @@ -277,7 +300,7 @@ function Remapper( local_horiz_indices = ArrayType(local_horiz_indices) field_values_size = ntuple(_ -> Nq, num_hdims) - field_values = ArrayType(zeros(FT, field_values_size)) + field_values = ArrayType(zeros(FT, field_values_size...)) # We represent interpolation onto an horizontal slab as an empty list of zcoords if isnothing(target_zcoords) || isempty(target_zcoords) @@ -285,9 +308,11 @@ function Remapper( vert_interpolation_weights = nothing vert_bounding_indices = nothing local_interpolated_values = - ArrayType(zeros(FT, size(local_horiz_indices))) - interpolated_values = - ArrayType(zeros(FT, size(local_target_hcoords_bitmask))) + ArrayType(zeros(FT, (buffer_length, size(local_horiz_indices)...))) + interpolated_values = ArrayType( + zeros(FT, (buffer_length, size(local_target_hcoords_bitmask)...)), + ) + num_dims = num_hdims + 1 else vert_interpolation_weights = ArrayType(vertical_interpolation_weights(space, target_zcoords)) @@ -297,16 +322,33 @@ function Remapper( # We have to add one extra dimension with respect to the bitmask/local_horiz_indices # because we are going to store the values for the columns local_interpolated_values = ArrayType( - zeros(FT, (size(local_horiz_indices)..., length(target_zcoords))), + zeros( + FT, + ( + buffer_length, + size(local_horiz_indices)..., + length(target_zcoords), + ), + ), ) interpolated_values = ArrayType( zeros( FT, - (size(local_target_hcoords_bitmask)..., length(target_zcoords)), + ( + buffer_length, + size(local_target_hcoords_bitmask)..., + length(target_zcoords), + ), ), ) + num_dims = num_hdims + 1 end + # We don't know how many dimensions an array might have, so we define a colons object + # that we can use to index with array[colons...] + + colons = ntuple(_ -> Colon(), num_dims) + return Remapper( comms_ctx, space, @@ -320,22 +362,27 @@ function Remapper( local_interpolated_values, field_values, interpolated_values, + buffer_length, + colons, ) end -Remapper(space::Spaces.AbstractSpace, target_hcoords::AbstractArray) = - Remapper(space, target_hcoords, nothing) +Remapper( + space::Spaces.AbstractSpace, + target_hcoords::AbstractArray; + buffer_length::Int = 1, +) = Remapper(space, target_hcoords, nothing; buffer_length) """ _set_interpolated_values!(remapper, field) -Change the local state of `remapper` by performing interpolation of `Fields` on the vertical +Change the local state of `remapper` by performing interpolation of `fields` on the vertical and horizontal points. 
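The `colons` tuple constructed in the `Remapper` above exists so that buffers whose rank depends on the target grid can be sliced uniformly; a minimal standalone sketch with plain arrays (sizes are illustrative only):

```julia
# (buffer_length, target dims...) scratch buffer, here for a 3D target grid.
buffer = zeros(3, 21, 21, 21)
colons = ntuple(_ -> Colon(), 3)       # one Colon per target dimension

# Select the first two buffered fields without knowing whether the target grid
# is 1D, 2D, or 3D: the same expression works for any rank.
batch = view(buffer, 1:2, colons...)   # size(batch) == (2, 21, 21, 21)
```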
""" -function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) +function _set_interpolated_values!(remapper::Remapper, fields) _set_interpolated_values!( remapper._local_interpolated_values, - field, + fields, remapper._field_values, remapper.local_horiz_indices, remapper.local_horiz_interpolation_weights, @@ -344,52 +391,55 @@ function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) ) end +# CPU, 3D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, (I1, I2)::NTuple{2}, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) - - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex - for j in 1:Nq, i in 1:Nq - scratch_field_values[i, j] = ( - A * field_values[i, j, nothing, v_lo, h] + - B * field_values[i, j, nothing, v_hi, h] - ) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for j in 1:Nq, i in 1:Nq + scratch_field_values[i, j] = ( + A * field_values[i, j, nothing, v_lo, h] + + B * field_values[i, j, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h end - prev_vindex, prev_lidx = vindex, h - end - tmp = zero(FT) + tmp = zero(FT) - for j in 1:Nq, i in 1:Nq - tmp += - I1[out_index, i] * - I2[out_index, j] * - scratch_field_values[i, j] + for j in 1:Nq, i in 1:Nq + tmp += + I1[out_index, i] * + I2[out_index, j] * + scratch_field_values[i, j] + end + out[field_index, out_index, vindex] = tmp end - out[out_index, vindex] = tmp end end end +# GPU, 3D case function set_interpolated_values_kernel!( out::AbstractArray, (I1, I2)::NTuple{2}, @@ -398,32 +448,46 @@ function set_interpolated_values_kernel!( vert_bounding_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! 
+    num_horiz = length(local_horiz_indices)
+    num_vert = length(vert_bounding_indices)
+    num_fields = length(field_values)

-    hindex = blockIdx().x
-    vindex = threadIdx().x
-    index = vindex + (hindex - 1) * blockDim().x
-    index > length(out) && return nothing
+    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
+    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
+    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

-    h = local_horiz_indices[hindex]
-    v_lo, v_hi = vert_bounding_indices[vindex]
-    A, B = vert_interpolation_weights[vindex]
+    totalThreadsX = gridDim().x * blockDim().x
+    totalThreadsY = gridDim().y * blockDim().y
+    totalThreadsZ = gridDim().z * blockDim().z

     _, Nq = size(I1)

-    out[hindex, vindex] = 0
-    for j in 1:Nq, i in 1:Nq
-        out[hindex, vindex] +=
-            I1[hindex, i] *
-            I2[hindex, j] *
-            (
-                A * field_values[i, j, nothing, v_lo, h] +
-                B * field_values[i, j, nothing, v_hi, h]
-            )
+    for i in hindex:totalThreadsX:num_horiz
+        h = local_horiz_indices[i]
+        for j in vindex:totalThreadsY:num_vert
+            v_lo, v_hi = vert_bounding_indices[j]
+            A, B = vert_interpolation_weights[j]
+            for k in findex:totalThreadsZ:num_fields
+                if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
+                    out[k, i, j] = 0
+                    for t in 1:Nq, s in 1:Nq
+                        out[k, i, j] +=
+                            I1[i, t] *
+                            I2[i, s] *
+                            (
+                                A * field_values[k][t, s, nothing, v_lo, h] +
+                                B * field_values[k][t, s, nothing, v_hi, h]
+                            )
+                    end
+                end
+            end
+        end
     end
-
     return nothing
 end

+# GPU, 2D case
 function set_interpolated_values_kernel!(
     out::AbstractArray,
     (I,)::NTuple{1},
@@ -432,87 +496,105 @@ function set_interpolated_values_kernel!(
     local_horiz_indices,
     vert_interpolation_weights,
     vert_bounding_indices,
     field_values,
 )
+    # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
+    num_horiz = length(local_horiz_indices)
+    num_vert = length(vert_bounding_indices)
+    num_fields = length(field_values)

-    hindex = blockIdx().x
-    vindex = threadIdx().x
-    index = vindex + (hindex - 1) * blockDim().x
-    index > length(out) && return nothing
+    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
+    vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
+    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

-    h = local_horiz_indices[hindex]
-    v_lo, v_hi = vert_bounding_indices[vindex]
-    A, B = vert_interpolation_weights[vindex]
+    totalThreadsX = gridDim().x * blockDim().x
+    totalThreadsY = gridDim().y * blockDim().y
+    totalThreadsZ = gridDim().z * blockDim().z

     _, Nq = size(I)

-    out[hindex, vindex] = 0
-    for i in 1:Nq
-        out[hindex, vindex] +=
-            I[hindex, i] * (
-                A * field_values[i, nothing, nothing, v_lo, h] +
-                B * field_values[i, nothing, nothing, v_hi, h]
-            )
+    for i in hindex:totalThreadsX:num_horiz
+        h = local_horiz_indices[i]
+        for j in vindex:totalThreadsY:num_vert
+            v_lo, v_hi = vert_bounding_indices[j]
+            A, B = vert_interpolation_weights[j]
+            for k in findex:totalThreadsZ:num_fields
+                if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
+                    out[k, i, j] = 0
+                    # Only the 1D interpolation matrix `I` applies here: the
+                    # horizontal space has a single dimension.
+                    for t in 1:Nq
+                        out[k, i, j] +=
+                            I[i, t] * (
+                                A * field_values[k][t, nothing, nothing, v_lo, h] +
+                                B * field_values[k][t, nothing, nothing, v_hi, h]
+                            )
+                    end
+                end
+            end
+        end
     end
-
     return nothing
 end

+# CPU, 2D case
 function set_interpolated_values_cpu_kernel!(
     out::AbstractArray,
-    field::Fields.Field,
+    fields::AbstractArray{<:Fields.Field},
     (I,)::NTuple{1},
     local_horiz_indices,
     vert_interpolation_weights,
     vert_bounding_indices,
     scratch_field_values,
-    field_values,
 )
-    space = axes(field)
+    space = 
axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for i in 1:Nq + scratch_field_values[i] = ( + A * field_values[i, nothing, nothing, v_lo, h] + + B * field_values[i, nothing, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h + end - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] + tmp = zero(FT) - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex for i in 1:Nq - scratch_field_values[i] = ( - A * field_values[i, nothing, nothing, v_lo, h] + - B * field_values[i, nothing, nothing, v_hi, h] - ) + tmp += I[out_index, i] * scratch_field_values[i] end - prev_vindex, prev_lidx = vindex, h + out[field_index, out_index, vindex] = tmp end - - tmp = zero(FT) - - for i in 1:Nq - tmp += I[out_index, i] * scratch_field_values[i] - end - out[out_index, vindex] = tmp end end end function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, scratch_field_values, local_horiz_indices, interpolation_matrix, vert_interpolation_weights::AbstractArray, vert_bounding_indices::AbstractArray, ) - - field_values = Fields.field_values(field) - - if ClimaComms.device(field) isa ClimaComms.CUDADevice + if ClimaComms.device(first(fields)) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) 
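+        # The kernel cannot take the array of `Field`s directly; gathering the
+        # underlying field values into a tuple gives an argument that the CUDA
+        # machinery can convert and ship to the device (hence the FIXME above
+        # about the tuple allocation).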
nblocks, _ = size(interpolation_matrix[1]) nthreads = length(vert_interpolation_weights) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -526,20 +608,20 @@ function _set_interpolated_values!( else set_interpolated_values_cpu_kernel!( out, - field, + fields, interpolation_matrix, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) end end +# Horizontal function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, _scratch_field_values, local_horiz_indices, local_horiz_interpolation_weights, @@ -547,16 +629,17 @@ function _set_interpolated_values!( ::Nothing, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) hdims = length(local_horiz_interpolation_weights) hdims in (1, 2) || error("Cannot handle $hdims horizontal dimensions") if ClimaComms.device(space) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) nitems = length(out) nthreads, nblocks = Topologies._configure_threadblock(nitems) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -566,20 +649,23 @@ function _set_interpolated_values!( field_values, ) else - for (out_index, h) in enumerate(local_horiz_indices) - out[out_index] = zero(FT) - if hdims == 2 - for j in 1:Nq, i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - local_horiz_interpolation_weights[2][out_index, j] * - field_values[i, j, nothing, nothing, h] - end - elseif hdims == 1 - for i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - field_values[i, nothing, nothing, nothing, h] + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + for (out_index, h) in enumerate(local_horiz_indices) + out[field_index, out_index] = zero(FT) + if hdims == 2 + for j in 1:Nq, i in 1:Nq + out[field_index, out_index] += + local_horiz_interpolation_weights[1][out_index, i] * + local_horiz_interpolation_weights[2][out_index, j] * + field_values[i, j, nothing, nothing, h] + end + elseif hdims == 1 + for i in 1:Nq + out[field_index, out_index] += + local_horiz_interpolation_weights[1][out_index, i] * + field_values[i, nothing, nothing, nothing, h] + end end end end @@ -592,21 +678,32 @@ function set_interpolated_values_kernel!( local_horiz_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! 
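+    # As in the kernels above, the x thread index strides over the horizontal
+    # target points and the z thread index strides over the batched fields.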
+    num_horiz = length(local_horiz_indices)
+    num_fields = length(field_values)

-    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-    index > length(out) && return nothing
+    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
+    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z
+
+    totalThreadsX = gridDim().x * blockDim().x
+    totalThreadsZ = gridDim().z * blockDim().z

-    h = local_horiz_indices[index]
     _, Nq = size(I1)

-    out[index] = 0
-    for j in 1:Nq, i in 1:Nq
-        out[index] +=
-            I1[index, i] *
-            I2[index, j] *
-            field_values[i, j, nothing, nothing, h]
+    for i in hindex:totalThreadsX:num_horiz
+        h = local_horiz_indices[i]
+        for k in findex:totalThreadsZ:num_fields
+            if i ≤ num_horiz && k ≤ num_fields
+                out[k, i] = 0
+                for t in 1:Nq, s in 1:Nq
+                    out[k, i] +=
+                        I1[i, t] *
+                        I2[i, s] *
+                        field_values[k][t, s, nothing, nothing, h]
+                end
+            end
+        end
     end
-
     return nothing
 end

@@ -616,47 +713,56 @@ function set_interpolated_values_kernel!(
     out::AbstractArray,
     (I,)::NTuple{1},
     local_horiz_indices,
     field_values,
 )
+    # TODO: Check the memory access pattern. This was not optimized and likely inefficient!
+    num_horiz = length(local_horiz_indices)
+    num_fields = length(field_values)
+
+    hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
+    findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

-    index = threadIdx().x + (blockIdx().x - 1) * blockDim().x
-    index > length(out) && return nothing
+    totalThreadsX = gridDim().x * blockDim().x
+    totalThreadsZ = gridDim().z * blockDim().z

-    h = local_horiz_indices[index]
     _, Nq = size(I)

-    out[index] = 0
-    for i in 1:Nq
-        out[index] +=
-            I[index, i] * field_values[i, nothing, nothing, nothing, h]
+    for i in hindex:totalThreadsX:num_horiz
+        h = local_horiz_indices[i]
+        for k in findex:totalThreadsZ:num_fields
+            if i ≤ num_horiz && k ≤ num_fields
+                out[k, i] = 0
+                # Only the 1D interpolation matrix `I` enters here; there is no
+                # second horizontal direction to loop over.
+                for t in 1:Nq
+                    out[k, i] +=
+                        I[i, t] *
+                        field_values[k][t, nothing, nothing, nothing, h]
+                end
+            end
+        end
     end
-
     return nothing
 end

-
 """
-    _apply_mpi_bitmask!(remapper::Remapper)
+    _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int)

Change to local (private) state of the `remapper` by applying the MPI bitmask and
reconstructing the correct shape for the interpolated values.

Internally, `remapper` performs interpolation on a flat list of points, this function moves
points around according to MPI-ownership and the expected output shape.
+
+`num_fields` is the number of fields that have been processed and have to be moved in the
+`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved.
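The masked, batched copy performed by `_apply_mpi_bitmask!` can be reproduced standalone with plain arrays; a minimal sketch (sizes and names are illustrative only):

```julia
# 2 buffered fields, 4 process-local target columns, 10 vertical levels.
local_vals = rand(2, 4, 10)

# Hypothetical 3×2 grid of target horizontal points, 4 of which are owned by
# this process (the `true` entries, taken in column-major order).
bitmask = Bool[1 0; 1 1; 0 1]

out = zeros(2, 3, 2, 10)
# The 2D bitmask consumes both horizontal dimensions of `out`, so the masked
# view has size (2, 4, 10) and lines up with `local_vals`.
view(out, 1:2, bitmask, :) .= view(local_vals, 1:2, :, :)
```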
""" -function _apply_mpi_bitmask!(remapper::Remapper) - if isnothing(remapper.target_zcoords) - # _interpolated_values[remapper.local_target_hcoords_bitmask] returns a view on - # space we want to write on - remapper._interpolated_values[remapper.local_target_hcoords_bitmask] .= - remapper._local_interpolated_values - else - # interpolated_values is an array of arrays properly ordered according to the bitmask - - # _interpolated_values[remapper.local_target_hcoords_bitmask, :] returns a - # view on space we want to write on - remapper._interpolated_values[ - remapper.local_target_hcoords_bitmask, - :, - ] .= remapper._local_interpolated_values - end +function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) + view( + remapper._interpolated_values, + 1:num_fields, + remapper.local_target_hcoords_bitmask, + :, + ) .= view( + remapper._local_interpolated_values, + 1:num_fields, + remapper.colons..., + ) end """ @@ -670,38 +776,104 @@ function _reset_interpolated_values!(remapper::Remapper) end """ - _collect_and_return_interpolated_values!(remapper::Remapper) + _collect_and_return_interpolated_values!(remapper::Remapper, + num_fields::Int) Perform an MPI call to aggregate the interpolated points from all the MPI processes and save the result in the local state of the `remapper`. Only the root process will return the interpolated data. `_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays. -""" -function _collect_and_return_interpolated_values!(remapper::Remapper) - ClimaComms.reduce!(remapper.comms_ctx, remapper._interpolated_values, +) - return ClimaComms.iamroot(remapper.comms_ctx) ? - Array(remapper._interpolated_values) : nothing -end -function _collect_interpolated_values!(dest, remapper::Remapper) - # MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular, - # if the destination is on the CPU, but the source is on the GPU, the values - # are automatically moved. +`num_fields` is the number of fields that have been interpolated in this batch. +""" +function _collect_and_return_interpolated_values!( + remapper::Remapper, + num_fields::Int, +) + output_array = + Array(remapper._interpolated_values[1:num_fields, remapper.colons...]) ClimaComms.reduce!( remapper.comms_ctx, - remapper._interpolated_values, - dest, + view(remapper._interpolated_values, 1:num_fields, remapper.colons...), + output_array, +, ) + return ClimaComms.iamroot(remapper.comms_ctx) ? output_array : nothing +end + +function _collect_interpolated_values!( + dest, + remapper::Remapper, + index_field_begin::Int, + index_field_end::Int, +) + + # In interpolate! we check that dest is consistent with what we have in the remapper. We + # can use this fact here to figure out if we are remapping only one field or more. + only_one_field = size(dest) == size(remapper._interpolated_values)[2:end] + field_selector = only_one_field ? () : (index_field_begin:index_field_end,) + + num_fields = 1 + index_field_end - index_field_begin + if (ClimaComms.device(remapper.space) isa ClimaComms.CUDADevice) + num_fields == remapper.buffer_length || error("Operation not supported") + # CUDA.jl does not support views very well at the moment. We can only work with + # num_fields = buffer_length + if only_one_field + ClimaComms.reduce!( + remapper.comms_ctx, + remapper._interpolated_values[1, remapper.colons...], + dest, + +, + ) + else + ClimaComms.reduce!( + remapper.comms_ctx, + remapper._interpolated_values, + dest, + +, + ) + end + else + # MPI.reduce! 
seems to behave nicely with respect to CPU/GPU. In particular, + # if the destination is on the CPU, but the source is on the GPU, the values + # are automatically moved. + ClimaComms.reduce!( + remapper.comms_ctx, + view( + remapper._interpolated_values, + 1:num_fields, + remapper.colons..., + ), + view(dest, field_selector..., remapper.colons...), + +, + ) + end return nothing end """ - interpolate(remapper::Remapper, field) - interpolate!(dest, remapper::Remapper, field) + batched_ranges(num_fields, buffer_length) -Interpolate the given `field` as prescribed by `remapper`. +Partition the indices from 1 to num_fields in such a way that no range is larger than +buffer_length. +""" +function batched_ranges(num_fields, buffer_length) + return [ + (i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for + i in 0:(div((num_fields - 1), buffer_length)) + ] +end + +""" + interpolate(remapper::Remapper, fields) + interpolate!(dest, remapper::Remapper, fields) + +Interpolate the given `field`(s) as prescribed by `remapper`. + +The optimal number of fields passed is the `buffer_length` of the `remapper`. If +more fields are passed, the `remapper` will batch work with size up to its +`buffer_length`. This call mutates the internal (private) state of the `remapper`. @@ -734,53 +906,105 @@ remapper = Remapper(space, hcoords, zcoords) int1 = interpolate(remapper, field1) int2 = interpolate(remapper, field2) + +# Or +int12 = interpolate(remapper, [field1, field2]) +# With int1 = int12[1, :, :, :] ``` """ -function interpolate(remapper::Remapper, field::T) where {T <: Fields.Field} +function interpolate(remapper::Remapper, fields) - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them up to - # obtain the final answer. Only the root will contain something useful. This also moves - # the data off the GPU - return _collect_and_return_interpolated_values!(remapper) + for field in fields + axes(field) == remapper.space || + error("Field is defined on a different space than remapper") + end + + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + # Partition the indices in such a way that nothing is larger than + # buffer_length + index_ranges = batched_ranges(length(fields), remapper.buffer_length) + + interpolated_values = mapreduce(vcat, index_ranges) do range + num_fields = length(range) + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. 
This also + # moves the data off the GPU + return _collect_and_return_interpolated_values!(remapper, num_fields) + end + + return only_one_field ? interpolated_values[begin, remapper.colons...] : + interpolated_values end +# dest has to be allowed to be nothing because interpolation happens only on the root +# process function interpolate!( dest::Union{Nothing, <:AbstractArray}, remapper::Remapper, - field::T, -) where {T <: Fields.Field} - - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + fields, +) + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end if !isnothing(dest) # !isnothing(dest) means that this is the root process, in this case, the size have - # to match - size(dest) == size(remapper._interpolated_values) || error( + # to match (ignoring the buffer_length) + dest_size = only_one_field ? size(dest) : size(dest)[2:end] + + dest_size == size(remapper._interpolated_values)[2:end] || error( "Destination array is not compatible with remapper (size mismatch)", ) end + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + while true + num_fields = 1 + index_field_end - index_field_begin + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + _collect_interpolated_values!( + dest, + remapper, + index_field_begin, + index_field_end, + ) - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them - # up to obtain the final answer. This also moves the data off the GPU. 
The - # output is written to the given destination - _collect_interpolated_values!(dest, remapper) + index_field_end != length(fields) || break + index_field_begin = index_field_begin + remapper.buffer_length + index_field_end = + min(length(fields), index_field_end + remapper.buffer_length) + end return nothing end diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 1e7bd7d39b..3fcb256954 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -25,7 +25,17 @@ atexit() do global_logger(prev_logger) end -if !(device isa ClimaComms.CUDADevice) +@testset "Utils" begin + # batched_ranges(num_fields, buffer_length) + @test Remapping.batched_ranges(1, 1) == [1:1] + @test Remapping.batched_ranges(1, 2) == [1:1] + @test Remapping.batched_ranges(2, 2) == [1:2] + @test Remapping.batched_ranges(3, 2) == [1:2, 3:3] +end + +on_gpu = device isa ClimaComms.CUDADevice + +if !on_gpu @testset "2D extruded" begin vertdomain = Domains.IntervalDomain( Geometry.ZPoint(0.0), @@ -63,7 +73,12 @@ if !(device isa ClimaComms.CUDADevice) hcoords = [Geometry.XPoint(x) for x in xpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = Remapping.Remapper( + hv_center_space, + hcoords, + zcoords, + buffer_length = 2, + ) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -79,6 +94,46 @@ if !(device isa ClimaComms.CUDADevice) @test interp_z[:, end] ≈ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts] end + + # Remapping two fields + interp_xx = Remapping.interpolate(remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xx[1, :, :, :] + @test interp_x ≈ interp_xx[2, :, :, :] + end + + # Remapping three fields (more than the buffer length) + interp_xxx = + Remapping.interpolate(remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xxx[1, :, :, :] + @test interp_x ≈ interp_xxx[2, :, :, :] + @test interp_x ≈ interp_xxx[3, :, :, :] + end + + # Remapping in-place one field + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_x ≈ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + dest = zeros(3, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_x ≈ dest[2, :, :, :] + @test interp_x ≈ dest[3, :, :, :] + end end end @@ -122,7 +177,8 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -143,9 +199,56 @@ end @test interp_z[:, :, end] ≈ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xy[1, :, :, :] + @test interp_y ≈ interp_xy[2, :, :, :] + end + # Remapping three fields (more than the buffer 
length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xyx[1, :, :, :] + @test interp_y ≈ interp_xyx[2, :, :, :] + @test interp_x ≈ interp_xyx[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper + dest = zeros(21, 21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_y ≈ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_y ≈ dest[2, :, :, :] + @test interp_x ≈ dest[3, :, :, :] + end + end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -159,6 +262,57 @@ end if ClimaComms.iamroot(context) @test interp_y ≈ [y for x in xpts, y in ypts] end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_xy[1, :, :] ≈ interp_x + @test interp_xy[2, :, :] ≈ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[1, :, :] ≈ interp_x + @test interp_xyx[2, :, :] ≈ interp_y + @test interp_xyx[3, :, :] ≈ interp_x + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? 
Remapping.Remapper(horiz_space, hcoords) : horiz_remapper + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :] + @test interp_y ≈ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [coords.x, coords.y, coords.x], + ) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :] + @test interp_y ≈ dest[2, :, :] + @test interp_x ≈ dest[3, :, :] + end + end end @@ -203,7 +357,8 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -225,9 +380,55 @@ end [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xy[1, :, :, :] + @test interp_y ≈ interp_xy[2, :, :, :] + end + # Remapping three fields (more than the buffer length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ interp_xyx[1, :, :, :] + @test interp_y ≈ interp_xyx[2, :, :, :] + @test interp_x ≈ interp_xyx[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? 
Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper + dest = zeros(21, 21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_y ≈ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :, :] + @test interp_y ≈ dest[2, :, :, :] + @test interp_x ≈ dest[3, :, :, :] + end + end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -241,6 +442,57 @@ end if ClimaComms.iamroot(context) @test interp_y ≈ [y for x in xpts, y in ypts] end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_xy[1, :, :] ≈ interp_x + @test interp_xy[2, :, :] ≈ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[1, :, :] ≈ interp_x + @test interp_xyx[2, :, :] ≈ interp_y + @test interp_xyx[3, :, :] ≈ interp_x + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(horiz_space, hcoords) : horiz_remapper + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :] + @test interp_y ≈ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [coords.x, coords.y, coords.x], + ) + if ClimaComms.iamroot(context) + @test interp_x ≈ dest[1, :, :] + @test interp_y ≈ dest[2, :, :] + @test interp_x ≈ dest[3, :, :] + end + end end @testset "3D sphere" begin @@ -274,7 +526,8 @@ end [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) coords = Fields.coordinate_field(hv_center_space) @@ -301,16 +554,67 @@ end [1000.0 * (29 / 30 + 30 / 30) / 2 for x in longpts, y in latpts] end - # Test interpolation in place + # Remapping two fields + interp_long_lat = + Remapping.interpolate(remapper, [sind.(coords.long), sind.(coords.lat)]) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ interp_long_lat[1, :, :, :] + @test interp_sin_lat ≈ interp_long_lat[2, :, :, :] + end + # Remapping three fields (more than the buffer length) + interp_long_lat_long = Remapping.interpolate( + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ interp_long_lat_long[1, :, :, :] + @test 
interp_sin_lat ≈ interp_long_lat_long[2, :, :, :] + @test interp_sin_long ≈ interp_long_lat_long[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper, sind.(coords.lat)) + Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat)], + ) if ClimaComms.iamroot(context) - @test dest ≈ [sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01 + @test interp_sin_long ≈ dest[1, :, :, :] + @test interp_sin_lat ≈ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ dest[1, :, :, :] + @test interp_sin_lat ≈ dest[2, :, :, :] + @test interp_sin_long ≈ dest[3, :, :, :] + end end # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -324,4 +628,64 @@ end if ClimaComms.iamroot(context) @test interp_sin_lat ≈ [sind(y) for x in longpts, y in latpts] rtol = 0.01 end + + # Two fields + interp_sin_long_lat = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat[1, :, :] ≈ interp_sin_long + @test interp_sin_long_lat[2, :, :] ≈ interp_sin_lat + end + + # Three fields + interp_sin_long_lat_long = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat_long[1, :, :] ≈ interp_sin_long + @test interp_sin_long_lat_long[2, :, :] ≈ interp_sin_lat + @test interp_sin_long_lat_long[3, :, :] ≈ interp_sin_long + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(horiz_space, hcoords) : horiz_remapper + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ dest + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ dest[1, :, :] + @test interp_sin_lat ≈ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !on_gpu + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long ≈ dest[1, :, :] + @test interp_sin_lat ≈ dest[2, :, :] + @test interp_sin_long ≈ dest[3, :, :] + end + end end
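Taken together, the tests above exercise the batched API in the following pattern (a condensed sketch; shapes follow the 3D cases with 21 target points per direction):

```julia
# Batched remapper: up to 2 fields are interpolated per kernel/MPI round.
remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2)

# Out-of-place: the leading dimension of the result indexes the fields.
interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y])
# On the root process, size(interp_xy) == (2, 21, 21, 21).

# In-place: `dest` carries the field dimension explicitly and is only filled on
# the root process.
dest = zeros(2, 21, 21, 21)
Remapping.interpolate!(dest, remapper, [coords.x, coords.y])
```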