diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2683663ddb..45509930cd 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -243,8 +243,14 @@ steps: agents: slurm_ntasks: 4 - - label: "Unit: distributed remapping" - key: distributed_remapping + - label: "Unit: distributed remapping (1 process)" + key: distributed_remapping_1proc + command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" + env: + CLIMACOMMS_DEVICE: "CPU" + + - label: "Unit: distributed remapping (2 processes)" + key: distributed_remapping_2procs command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: CLIMACOMMS_CONTEXT: "MPI" @@ -252,21 +258,23 @@ steps: agents: slurm_ntasks: 2 - - label: "Unit: distributed remapping (1 process)" - key: distributed_remapping_1proc + - label: "Unit: distributed remapping with CUDA (1 process)" + key: distributed_remapping_gpu_1proc command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: - CLIMACOMMS_DEVICE: "CPU" + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_gpus: 1 - - label: "Unit: distributed remapping with CUDA" - key: distributed_remapping_gpu + - label: "Unit: distributed remapping with CUDA (2 processes)" + key: distributed_remapping_gpu_2procs command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: CLIMACOMMS_CONTEXT: "MPI" CLIMACOMMS_DEVICE: "CUDA" agents: slurm_ntasks: 2 - slurm_gpus: 1 + slurm_gpus_per_task: 1 - label: "Unit: distributed gather" key: unit_distributed_gather4 diff --git a/NEWS.md b/NEWS.md index 81444b5399..98007489e1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,12 @@ ClimaCore.jl Release Notes main ------- +- ![][badge-πŸ€–precisionΞ”] ![][badge-πŸ’₯breaking] `Remapper`s can now process + multiple `Field`s at the same time if created with `buffer_length > 1` + ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)). + Machine-precision differences are expected. This change is breaking because + remappers now return the same array type as the input field. + v0.13.4 ------- diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 72178f8519..5e0c9d5cc6 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -78,6 +78,11 @@ # process-local interpolate points in the correct shape with respect to the global # interpolation (and where to collect results) # +# To process multiple Fields at the same time, some of the scratch spaces gain an extra +# dimension (buffer_length). With this extra dimension, we can batch the work and process up +# to buffer_length fields at once. This reduces the number of kernel launches and +# MPI calls. +# # For GPU runs, we store all the vectors as CuArrays on the GPU. """ @@ -124,6 +129,7 @@ struct Remapper{ T8 <: AbstractArray, T9 <: AbstractArray, T10 <: AbstractArray, + T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}}, } comms_ctx::CC @@ -170,25 +176,32 @@ struct Remapper{ # Scratch space where we save the process-local interpolated values. We keep overwriting # this to avoid extra allocations. This is a linear array with the same length as - # local_horiz_indices. + # local_horiz_indices with an extra dimension of length buffer_length added. _local_interpolated_values::T8 # Scratch space where we save the process-local field value. 
We keep overwriting this to # avoid extra allocations. Ideally, we wouldn't need this and we would use views for - # everything. This has dimensions (Nq) or (Nq, Nq) depending if the horizontal space is - # 1D or 2D. + # everything. This has dimensions (Nq, ) or (Nq, Nq, ) + # depending on whether the horizontal space is 1D or 2D. _field_values::T9 # Storage area where the interpolated values are saved. This is meaningful only for the - # root process and gets filled by a interpolate call. This has dimensions (H, V), where - # H is the size of target_hcoords and V of target_zcoords. In other words, this is the - # expected output array. + # root process and gets filled by an interpolate call. This has dimensions + # (H, V, buffer_length), where H is the size of target_hcoords and V of target_zcoords. + # In other words, this is the expected output array. _interpolated_values::T10 + + # Maximum number of Fields that can be interpolated at any given time + buffer_length::Int + + # A tuple of Colons (1, 2, or 3), used to more easily get views into arrays with unknown + # dimension (1-3D) + colons::T11 end """ - Remapper(space, target_hcoords, target_zcoords) - Remapper(space, target_hcoords) + Remapper(space, target_hcoords, target_zcoords; buffer_length = 1) + Remapper(space, target_hcoords; buffer_length = 1) Return a `Remapper` responsible for interpolating any `Field` defined on the given `space` to the Cartesian product of `target_hcoords` with `target_zcoords`. @@ -199,11 +212,21 @@ The `Remapper` is designed to not be tied to any particular `Field`. You can use `Remapper` for any `Field` as long as they are all defined on the same `topology`. `Remapper` is the main argument to the `interpolate` function. + +Keyword arguments +================= + +`buffer_length` is the size of the internal buffer in the Remapper to store intermediate +values for interpolation. Effectively, this controls how many fields can be remapped +simultaneously in `interpolate`. When more fields than `buffer_length` are passed, the +remapper will batch the work in chunks of size `buffer_length`. 
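+ +For example, a `Remapper` created with `buffer_length = 2` and asked to interpolate three +`Field`s will process them in two batches: the first two together, then the remaining one.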
+ """ function Remapper( space::Spaces.AbstractSpace, target_hcoords::AbstractArray, - target_zcoords::Union{AbstractArray, Nothing}, + target_zcoords::Union{AbstractArray, Nothing}; + buffer_length::Int = 1, ) comms_ctx = ClimaComms.context(space) @@ -277,7 +300,7 @@ function Remapper( local_horiz_indices = ArrayType(local_horiz_indices) field_values_size = ntuple(_ -> Nq, num_hdims) - field_values = ArrayType(zeros(FT, field_values_size)) + field_values = ArrayType(zeros(FT, field_values_size...)) # We represent interpolation onto an horizontal slab as an empty list of zcoords if isnothing(target_zcoords) || isempty(target_zcoords) @@ -285,9 +308,11 @@ function Remapper( vert_interpolation_weights = nothing vert_bounding_indices = nothing local_interpolated_values = - ArrayType(zeros(FT, size(local_horiz_indices))) - interpolated_values = - ArrayType(zeros(FT, size(local_target_hcoords_bitmask))) + ArrayType(zeros(FT, (size(local_horiz_indices)..., buffer_length))) + interpolated_values = ArrayType( + zeros(FT, (size(local_target_hcoords_bitmask)..., buffer_length)), + ) + num_dims = num_hdims else vert_interpolation_weights = ArrayType(vertical_interpolation_weights(space, target_zcoords)) @@ -297,16 +322,33 @@ function Remapper( # We have to add one extra dimension with respect to the bitmask/local_horiz_indices # because we are going to store the values for the columns local_interpolated_values = ArrayType( - zeros(FT, (size(local_horiz_indices)..., length(target_zcoords))), + zeros( + FT, + ( + size(local_horiz_indices)..., + length(target_zcoords), + buffer_length, + ), + ), ) interpolated_values = ArrayType( zeros( FT, - (size(local_target_hcoords_bitmask)..., length(target_zcoords)), + ( + size(local_target_hcoords_bitmask)..., + length(target_zcoords), + buffer_length, + ), ), ) + num_dims = num_hdims + 1 end + # We don't know how many dimensions an array might have, so we define a colons object + # that we can use to index with array[colons...] + + colons = ntuple(_ -> Colon(), num_dims) + return Remapper( comms_ctx, space, @@ -320,22 +362,27 @@ function Remapper( local_interpolated_values, field_values, interpolated_values, + buffer_length, + colons, ) end -Remapper(space::Spaces.AbstractSpace, target_hcoords::AbstractArray) = - Remapper(space, target_hcoords, nothing) +Remapper( + space::Spaces.AbstractSpace, + target_hcoords::AbstractArray; + buffer_length::Int = 1, +) = Remapper(space, target_hcoords, nothing; buffer_length) """ _set_interpolated_values!(remapper, field) -Change the local state of `remapper` by performing interpolation of `Fields` on the vertical +Change the local state of `remapper` by performing interpolation of `fields` on the vertical and horizontal points. 
""" -function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) +function _set_interpolated_values!(remapper::Remapper, fields) _set_interpolated_values!( remapper._local_interpolated_values, - field, + fields, remapper._field_values, remapper.local_horiz_indices, remapper.local_horiz_interpolation_weights, @@ -344,52 +391,55 @@ function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) ) end +# CPU, 3D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, (I1, I2)::NTuple{2}, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) - - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex - for j in 1:Nq, i in 1:Nq - scratch_field_values[i, j] = ( - A * field_values[i, j, nothing, v_lo, h] + - B * field_values[i, j, nothing, v_hi, h] - ) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for j in 1:Nq, i in 1:Nq + scratch_field_values[i, j] = ( + A * field_values[i, j, nothing, v_lo, h] + + B * field_values[i, j, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h end - prev_vindex, prev_lidx = vindex, h - end - tmp = zero(FT) + tmp = zero(FT) - for j in 1:Nq, i in 1:Nq - tmp += - I1[out_index, i] * - I2[out_index, j] * - scratch_field_values[i, j] + for j in 1:Nq, i in 1:Nq + tmp += + I1[out_index, i] * + I2[out_index, j] * + scratch_field_values[i, j] + end + out[out_index, vindex, field_index] = tmp end - out[out_index, vindex] = tmp end end end +# GPU, 3D case function set_interpolated_values_kernel!( out::AbstractArray, (I1, I2)::NTuple{2}, @@ -398,32 +448,46 @@ function set_interpolated_values_kernel!( vert_bounding_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! 
+ num_horiz = length(local_horiz_indices) + num_vert = length(vert_bounding_indices) + num_fields = length(field_values) - hindex = blockIdx().x - vindex = threadIdx().x - index = vindex + (hindex - 1) * blockDim().x - index > length(out) && return nothing + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - h = local_horiz_indices[hindex] - v_lo, v_hi = vert_bounding_indices[vindex] - A, B = vert_interpolation_weights[vindex] + totalThreadsX = gridDim().x * blockDim().x + totalThreadsY = gridDim().y * blockDim().y + totalThreadsZ = gridDim().z * blockDim().z _, Nq = size(I1) - out[hindex, vindex] = 0 - for j in 1:Nq, i in 1:Nq - out[hindex, vindex] += - I1[hindex, i] * - I2[hindex, j] * - ( - A * field_values[i, j, nothing, v_lo, h] + - B * field_values[i, j, nothing, v_hi, h] - ) + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for j in vindex:totalThreadsY:num_vert + v_lo, v_hi = vert_bounding_indices[j] + A, B = vert_interpolation_weights[j] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields + out[i, j, k] = 0 + for t in 1:Nq, s in 1:Nq + out[i, j, k] += + I1[i, t] * + I2[i, s] * + ( + A * field_values[k][t, s, nothing, v_lo, h] + + B * field_values[k][t, s, nothing, v_hi, h] + ) + end + end + end + end end - return nothing end +# GPU, 2D case function set_interpolated_values_kernel!( out::AbstractArray, (I,)::NTuple{1}, @@ -432,87 +496,105 @@ function set_interpolated_values_kernel!( vert_bounding_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_vert = length(vert_bounding_indices) + num_fields = length(field_values) - hindex = blockIdx().x - vindex = threadIdx().x - index = vindex + (hindex - 1) * blockDim().x - index > length(out) && return nothing + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - h = local_horiz_indices[hindex] - v_lo, v_hi = vert_bounding_indices[vindex] - A, B = vert_interpolation_weights[vindex] + totalThreadsX = gridDim().x * blockDim().x + totalThreadsY = gridDim().y * blockDim().y + totalThreadsZ = gridDim().z * blockDim().z _, Nq = size(I) - out[hindex, vindex] = 0 - for i in 1:Nq - out[hindex, vindex] += - I[hindex, i] * ( - A * field_values[i, nothing, nothing, v_lo, h] + - B * field_values[i, nothing, nothing, v_hi, h] - ) + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for j in vindex:totalThreadsY:num_vert + v_lo, v_hi = vert_bounding_indices[j] + A, B = vert_interpolation_weights[j] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields + out[i, j, k] = 0 + for t in 1:Nq + out[i, j, k] += + I[i, t] * + ( + A * + field_values[k][t, nothing, nothing, v_lo, h] + + B * + field_values[k][t, nothing, nothing, v_hi, h] + ) + end + end + end + end end - return nothing end +# CPU, 2D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, (I,)::NTuple{1}, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) - space = axes(field) + space = axes(first(fields)) FT = 
Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for i in 1:Nq + scratch_field_values[i] = ( + A * field_values[i, nothing, nothing, v_lo, h] + + B * field_values[i, nothing, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h + end - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] + tmp = zero(FT) - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex for i in 1:Nq - scratch_field_values[i] = ( - A * field_values[i, nothing, nothing, v_lo, h] + - B * field_values[i, nothing, nothing, v_hi, h] - ) + tmp += I[out_index, i] * scratch_field_values[i] end - prev_vindex, prev_lidx = vindex, h + out[out_index, vindex, field_index] = tmp end - - tmp = zero(FT) - - for i in 1:Nq - tmp += I[out_index, i] * scratch_field_values[i] - end - out[out_index, vindex] = tmp end end end function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, scratch_field_values, local_horiz_indices, interpolation_matrix, vert_interpolation_weights::AbstractArray, vert_bounding_indices::AbstractArray, ) - - field_values = Fields.field_values(field) - - if ClimaComms.device(field) isa ClimaComms.CUDADevice + if ClimaComms.device(first(fields)) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) 
nblocks, _ = size(interpolation_matrix[1]) nthreads = length(vert_interpolation_weights) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -526,20 +608,20 @@ function _set_interpolated_values!( else set_interpolated_values_cpu_kernel!( out, - field, + fields, interpolation_matrix, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) end end +# Horizontal function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, _scratch_field_values, local_horiz_indices, local_horiz_interpolation_weights, @@ -547,16 +629,17 @@ function _set_interpolated_values!( ::Nothing, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) hdims = length(local_horiz_interpolation_weights) hdims in (1, 2) || error("Cannot handle $hdims horizontal dimensions") if ClimaComms.device(space) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) nitems = length(out) nthreads, nblocks = Topologies._configure_threadblock(nitems) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -566,20 +649,23 @@ function _set_interpolated_values!( field_values, ) else - for (out_index, h) in enumerate(local_horiz_indices) - out[out_index] = zero(FT) - if hdims == 2 - for j in 1:Nq, i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - local_horiz_interpolation_weights[2][out_index, j] * - field_values[i, j, nothing, nothing, h] - end - elseif hdims == 1 - for i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - field_values[i, nothing, nothing, nothing, h] + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + for (out_index, h) in enumerate(local_horiz_indices) + out[out_index, field_index] = zero(FT) + if hdims == 2 + for j in 1:Nq, i in 1:Nq + out[out_index, field_index] += + local_horiz_interpolation_weights[1][out_index, i] * + local_horiz_interpolation_weights[2][out_index, j] * + field_values[i, j, nothing, nothing, h] + end + elseif hdims == 1 + for i in 1:Nq + out[out_index, field_index] += + local_horiz_interpolation_weights[1][out_index, i] * + field_values[i, nothing, nothing, nothing, h] + end end end end @@ -592,21 +678,32 @@ function set_interpolated_values_kernel!( local_horiz_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! 
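+    # Same grid-stride pattern as in the 3D kernels above, but only over the x (horizontal + # points) and z (fields) dimensions, since purely horizontal remapping has no vertical + # levels.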
+ num_horiz = length(local_horiz_indices) + num_fields = length(field_values) + + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - index > length(out) && return nothing + totalThreadsX = gridDim().x * blockDim().x + totalThreadsZ = gridDim().z * blockDim().z - h = local_horiz_indices[index] _, Nq = size(I1) - out[index] = 0 - for j in 1:Nq, i in 1:Nq - out[index] += - I1[index, i] * - I2[index, j] * - field_values[i, j, nothing, nothing, h] + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && k ≀ num_fields + out[i, k] = 0 + for t in 1:Nq, s in 1:Nq + out[i, k] += + I1[i, t] * + I2[i, s] * + field_values[k][t, s, nothing, nothing, h] + end + end + end end - return nothing end @@ -616,46 +713,59 @@ function set_interpolated_values_kernel!( local_horiz_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_fields = length(field_values) + + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - index > length(out) && return nothing + totalThreadsX = gridDim().x * blockDim().x + totalThreadsZ = gridDim().z * blockDim().z - h = local_horiz_indices[index] _, Nq = size(I) - out[index] = 0 - for i in 1:Nq - out[index] += - I[index, i] * field_values[i, nothing, nothing, nothing, h] + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && k ≀ num_fields + out[i, k] = 0 + for t in 1:Nq + out[i, k] += + I[i, t] * + field_values[k][t, nothing, nothing, nothing, h] + end + end + end end - return nothing end - """ - _apply_mpi_bitmask!(remapper::Remapper) + _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) Change to local (private) state of the `remapper` by applying the MPI bitmask and reconstructing the correct shape for the interpolated values. Internally, `remapper` performs interpolation on a flat list of points, this function moves points around according to MPI-ownership and the expected output shape. + +`num_fields` is the number of fields that have been processed and have to be moved in the +`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved. 
""" -function _apply_mpi_bitmask!(remapper::Remapper) +function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) if isnothing(remapper.target_zcoords) - # _interpolated_values[remapper.local_target_hcoords_bitmask] returns a view on - # space we want to write on - remapper._interpolated_values[remapper.local_target_hcoords_bitmask] .= - remapper._local_interpolated_values + view( + remapper._interpolated_values, + remapper.local_target_hcoords_bitmask, + 1:num_fields, + ) .= view(remapper._local_interpolated_values, :, 1:num_fields) else - # interpolated_values is an array of arrays properly ordered according to the bitmask - - # _interpolated_values[remapper.local_target_hcoords_bitmask, :] returns a - # view on space we want to write on - remapper._interpolated_values[ + view( + remapper._interpolated_values, remapper.local_target_hcoords_bitmask, :, - ] .= remapper._local_interpolated_values + 1:num_fields, + ) .= view(remapper._local_interpolated_values, :, :, 1:num_fields) end end @@ -670,38 +780,80 @@ function _reset_interpolated_values!(remapper::Remapper) end """ - _collect_and_return_interpolated_values!(remapper::Remapper) + _collect_and_return_interpolated_values!(remapper::Remapper, + num_fields::Int) Perform an MPI call to aggregate the interpolated points from all the MPI processes and save the result in the local state of the `remapper`. Only the root process will return the interpolated data. `_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays. + +`num_fields` is the number of fields that have been interpolated in this batch. """ -function _collect_and_return_interpolated_values!(remapper::Remapper) - ClimaComms.reduce!(remapper.comms_ctx, remapper._interpolated_values, +) - return ClimaComms.iamroot(remapper.comms_ctx) ? - Array(remapper._interpolated_values) : nothing +function _collect_and_return_interpolated_values!( + remapper::Remapper, + num_fields::Int, +) + return ClimaComms.reduce( + remapper.comms_ctx, + remapper._interpolated_values[remapper.colons..., 1:num_fields], + +, + ) end -function _collect_interpolated_values!(dest, remapper::Remapper) - # MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular, - # if the destination is on the CPU, but the source is on the GPU, the values - # are automatically moved. +function _collect_interpolated_values!( + dest, + remapper::Remapper, + index_field_begin::Int, + index_field_end::Int; + only_one_field, +) + + if only_one_field + ClimaComms.reduce!( + remapper.comms_ctx, + remapper._interpolated_values[remapper.colons..., begin], + dest, + +, + ) + return nothing + end + + num_fields = 1 + index_field_end - index_field_begin + ClimaComms.reduce!( remapper.comms_ctx, - remapper._interpolated_values, - dest, + view(remapper._interpolated_values, remapper.colons..., 1:num_fields), + view(dest, remapper.colons..., index_field_begin:index_field_end), +, ) + return nothing end """ - interpolate(remapper::Remapper, field) - interpolate!(dest, remapper::Remapper, field) + batched_ranges(num_fields, buffer_length) + +Partition the indices from 1 to num_fields in such a way that no range is larger than +buffer_length. 
+""" +function batched_ranges(num_fields, buffer_length) + return [ + (i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for + i in 0:(div((num_fields - 1), buffer_length)) + ] +end + +""" + interpolate(remapper::Remapper, fields) + interpolate!(dest, remapper::Remapper, fields) + +Interpolate the given `field`(s) as prescribed by `remapper`. -Interpolate the given `field` as prescribed by `remapper`. +The optimal number of fields passed is the `buffer_length` of the `remapper`. If +more fields are passed, the `remapper` will batch work with size up to its +`buffer_length`. This call mutates the internal (private) state of the `remapper`. @@ -715,6 +867,9 @@ to be defined on the root process and to be `nothing` for the other processes. Note: `interpolate` allocates new arrays and has some internal type-instability, `interpolate!` is non-allocating and type-stable. +When using `interpolate!`, the `dest`ination has to be the same array type as the +device in use (e.g., `CuArray` for CUDA runs). + Example ======== @@ -734,53 +889,120 @@ remapper = Remapper(space, hcoords, zcoords) int1 = interpolate(remapper, field1) int2 = interpolate(remapper, field2) + +# Or +int12 = interpolate(remapper, [field1, field2]) +# With int1 = int12[1, :, :, :] ``` """ -function interpolate(remapper::Remapper, field::T) where {T <: Fields.Field} +function interpolate(remapper::Remapper, fields) - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them up to - # obtain the final answer. Only the root will contain something useful. This also moves - # the data off the GPU - return _collect_and_return_interpolated_values!(remapper) + for field in fields + axes(field) == remapper.space || + error("Field is defined on a different space than remapper") + end + + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + # Partition the indices in such a way that nothing is larger than + # buffer_length + index_ranges = batched_ranges(length(fields), remapper.buffer_length) + + cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1) + + interpolated_values = mapreduce(cat_fn, index_ranges) do range + num_fields = length(range) + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + ret = _collect_and_return_interpolated_values!(remapper, num_fields) + return ret + end + + # Non-root processes + isnothing(interpolated_values) && return nothing + + return only_one_field ? 
interpolated_values[remapper.colons..., begin] : + interpolated_values end +# dest has to be allowed to be nothing because interpolation happens only on the root +# process function interpolate!( dest::Union{Nothing, <:AbstractArray}, remapper::Remapper, - field::T, -) where {T <: Fields.Field} - - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + fields, +) + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end if !isnothing(dest) # !isnothing(dest) means that this is the root process, in this case, the size have - # to match - size(dest) == size(remapper._interpolated_values) || error( + # to match (ignoring the buffer_length) + dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)] + + dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error( "Destination array is not compatible with remapper (size mismatch)", ) + + expected_array_type = + ClimaComms.array_type(ClimaComms.device(remapper.comms_ctx)) + + found_type = nameof(typeof(dest)) + + dest isa expected_array_type || + error("dest is a $found_type, expected $expected_array_type") end + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + while true + num_fields = 1 + index_field_end - index_field_begin + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + _collect_interpolated_values!( + dest, + remapper, + index_field_begin, + index_field_end; + only_one_field, + ) - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them - # up to obtain the final answer. This also moves the data off the GPU. The - # output is written to the given destination - _collect_interpolated_values!(dest, remapper) + index_field_end != length(fields) || break + index_field_begin = index_field_begin + remapper.buffer_length + index_field_end = + min(length(fields), index_field_end + remapper.buffer_length) + end return nothing end diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 1e7bd7d39b..6ae9dd54f9 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -17,6 +17,7 @@ using ClimaComms const context = ClimaComms.context() const pid, nprocs = ClimaComms.init(context) const device = ClimaComms.device() +ArrayType = ClimaComms.array_type(device) # log output only from root process logger_stream = ClimaComms.iamroot(context) ? 
stderr : devnull @@ -25,7 +26,18 @@ atexit() do global_logger(prev_logger) end -if !(device isa ClimaComms.CUDADevice) +@testset "Utils" begin + # batched_ranges(num_fields, buffer_length) + @test Remapping.batched_ranges(1, 1) == [1:1] + @test Remapping.batched_ranges(1, 2) == [1:1] + @test Remapping.batched_ranges(2, 2) == [1:2] + @test Remapping.batched_ranges(3, 2) == [1:2, 3:3] +end + +on_gpu = device isa ClimaComms.CUDADevice +broken = false + +if !on_gpu @testset "2D extruded" begin vertdomain = Domains.IntervalDomain( Geometry.ZPoint(0.0), @@ -63,22 +75,67 @@ if !(device isa ClimaComms.CUDADevice) hcoords = [Geometry.XPoint(x) for x in xpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = Remapping.Remapper( + hv_center_space, + hcoords, + zcoords, + buffer_length = 2, + ) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, z in zpts] end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in xpts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, 2:(end - 1)] β‰ˆ expected_z[:, 2:(end - 1)] - @test interp_z[:, 1] β‰ˆ + @test Array(interp_z[:, 2:(end - 1)]) β‰ˆ expected_z[:, 2:(end - 1)] + @test Array(interp_z[:, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts] - @test interp_z[:, end] β‰ˆ + @test Array(interp_z[:, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts] end + + # Remapping two fields + interp_xx = Remapping.interpolate(remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xx[:, :, 1] + @test interp_x β‰ˆ interp_xx[:, :, 2] + end + + # Remapping three fields (more than the buffer length) + interp_xxx = + Remapping.interpolate(remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xxx[:, :, 1] + @test interp_x β‰ˆ interp_xxx[:, :, 2] + @test interp_x β‰ˆ interp_xxx[:, :, 3] + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + # Two fields + dest = ArrayType(zeros(21, 21, 2)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_x β‰ˆ dest[:, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_x β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] + end end end @@ -122,43 +179,129 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts, z in zpts] end interp_y = Remapping.interpolate(remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts, z in zpts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts, z in zpts] end interp_z = Remapping.interpolate(remapper, 
coords.z) expected_z = [z for x in xpts, y in ypts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts, y in ypts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xy[:, :, :, 1] + @test interp_y β‰ˆ interp_xy[:, :, :, 2] + end + # Remapping three fields (more than the buffer length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xyx[:, :, :, 1] + @test interp_y β‰ˆ interp_xyx[:, :, :, 2] + @test interp_x β‰ˆ interp_xyx[:, :, :, 3] + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + # Two fields + dest = ArrayType(zeros(21, 21, 21, 2)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] + end + + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) interp_x = Remapping.interpolate(horiz_remapper, coords.x) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts] end interp_y = Remapping.interpolate(horiz_remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts] + end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_xy[:, :, 1] β‰ˆ interp_x + @test interp_xy[:, :, 2] β‰ˆ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[:, :, 1] β‰ˆ interp_x + @test interp_xyx[:, :, 2] β‰ˆ interp_y + @test interp_xyx[:, :, 3] β‰ˆ interp_x + end + + # Remapping in-place one field + # + # We also have to check the single-field case, to make sure interpolate! works on GPU + # when only one buffer slot is filled + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest end + + + # Two fields + dest = ArrayType(zeros(21, 21, 2)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 
21, 3)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] + end + end @@ -203,43 +346,124 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts, z in zpts] end interp_y = Remapping.interpolate(remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts, z in zpts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts, z in zpts] end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in xpts, y in ypts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts, y in ypts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xy[:, :, :, 1] + @test interp_y β‰ˆ interp_xy[:, :, :, 2] + end + # Remapping three fields (more than the buffer length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xyx[:, :, :, 1] + @test interp_y β‰ˆ interp_xyx[:, :, :, 2] + @test interp_x β‰ˆ interp_xyx[:, :, :, 3] + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + # Two fields + dest = ArrayType(zeros(21, 21, 21, 2)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] + end + + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) interp_x = Remapping.interpolate(horiz_remapper, coords.x) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts] end interp_y = Remapping.interpolate(horiz_remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts] + end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) 
+ if ClimaComms.iamroot(context) + @test interp_xy[:, :, 1] β‰ˆ interp_x + @test interp_xy[:, :, 2] β‰ˆ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[:, :, 1] β‰ˆ interp_x + @test interp_xyx[:, :, 2] β‰ˆ interp_y + @test interp_xyx[:, :, 3] β‰ˆ interp_x + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + + # Two fields + dest = ArrayType(zeros(21, 21, 2)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end @@ -274,54 +498,151 @@ end [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) coords = Fields.coordinate_field(hv_center_space) interp_sin_long = Remapping.interpolate(remapper, sind.(coords.long)) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ + @test Array(interp_sin_long) β‰ˆ [sind(x) for x in longpts, y in latpts, z in zpts] rtol = 0.01 end interp_sin_lat = Remapping.interpolate(remapper, sind.(coords.lat)) if ClimaComms.iamroot(context) - @test interp_sin_lat β‰ˆ + @test Array(interp_sin_lat) β‰ˆ [sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01 end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in longpts, y in latpts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in longpts, y in latpts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in longpts, y in latpts] end - # Test interpolation in place - dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper, sind.(coords.lat)) + # Remapping two fields + interp_long_lat = + Remapping.interpolate(remapper, [sind.(coords.long), sind.(coords.lat)]) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ interp_long_lat[:, :, :, 1] + @test interp_sin_lat β‰ˆ interp_long_lat[:, :, :, 2] + end + # Remapping three fields (more than the buffer length) + interp_long_lat_long = Remapping.interpolate( + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ interp_long_lat_long[:, :, :, 1] + @test interp_sin_lat β‰ˆ interp_long_lat_long[:, :, :, 2] + @test interp_sin_long β‰ˆ interp_long_lat_long[:, :, :, 3] + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest + end + + # Two fields + dest = 
ArrayType(zeros(21, 21, 21, 2)) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat)], + ) if ClimaComms.iamroot(context) - @test dest β‰ˆ [sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01 + @test interp_sin_long β‰ˆ dest[:, :, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[:, :, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, :, 3] end # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) interp_sin_long = Remapping.interpolate(horiz_remapper, sind.(coords.long)) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ [sind(x) for x in longpts, y in latpts] rtol = 0.01 + @test Array(interp_sin_long) β‰ˆ [sind(x) for x in longpts, y in latpts] rtol = 0.01 end interp_sin_lat = Remapping.interpolate(horiz_remapper, sind.(coords.lat)) if ClimaComms.iamroot(context) - @test interp_sin_lat β‰ˆ [sind(y) for x in longpts, y in latpts] rtol = 0.01 + @test Array(interp_sin_lat) β‰ˆ [sind(y) for x in longpts, y in latpts] rtol = 0.01 + end + + # Two fields + interp_sin_long_lat = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat[:, :, 1] β‰ˆ interp_sin_long + @test interp_sin_long_lat[:, :, 2] β‰ˆ interp_sin_lat + end + + # Three fields + interp_sin_long_lat_long = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat_long[:, :, 1] β‰ˆ interp_sin_long + @test interp_sin_long_lat_long[:, :, 2] β‰ˆ interp_sin_lat + @test interp_sin_long_lat_long[:, :, 3] β‰ˆ interp_sin_long + end + + # Remapping in-place one field + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest + end + + # Two fields + dest = ArrayType(zeros(21, 21, 2)) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[:, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, 2] + end + + # Three fields (more than buffer length) + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[:, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, 3] end end
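As a usage reference, the following is a minimal sketch of the batched API exercised by the tests above. The setup is hypothetical: `space` stands for any extruded 3D space and `hcoords`/`zcoords` for target coordinates built as in the tests; only the `Remapper`, `interpolate`, and `interpolate!` calls reflect the actual interface.

```julia
using ClimaCore: Fields, Remapping

# Hypothetical setup: `space` is an extruded 3D space; `hcoords` and `zcoords`
# are the target interpolation points (21 each, say), built as in the tests.
coords = Fields.coordinate_field(space)

# A buffer of length 2 lets the remapper process two fields per batch.
remapper = Remapping.Remapper(space, hcoords, zcoords; buffer_length = 2)

# Three fields with buffer_length = 2 are processed in two batches (ranges 1:2
# and 3:3); the fields are stacked along the last dimension of the output.
out = Remapping.interpolate(remapper, [coords.x, coords.y, coords.z])

# On MPI runs, only the root process receives the result (`nothing` elsewhere);
# on CUDA runs, `out` is a CuArray.
interp_x = out[:, :, :, 1]

# In-place variant: `dest` needs a trailing dimension for the fields and must
# match the device array type (e.g., CuArray for CUDA runs).
dest = similar(out)
Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.z])
```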