From 38ffafd4d640669c509f0728cd1e2962079919e6 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Wed, 10 Apr 2024 21:37:25 -0700 Subject: [PATCH 1/3] Allow remapper to take multiple Fields This PR increases the rank of some internal arrays of the remapper by 1. The new dimension is to allow the remapper to process multiple Fields at the same time. The idea is the following: when creating a Remapper, one can specify a buffer_length. When the buffer_length is larger than one, the Remapper will preallocate space to be able to interpolate buffer_length at the same time. Then, when `interpolate` is called, the Remapper can work with any number of Fields and the work is divided in batches of buffer_length size. E.g, If buffer_length is 10 and 22 fields are to be interpolated, the work is processed in groups of 10+10+2. The only cost of choosing a large buffer_length is memory (there shouldn't be any runtime penalty in interpolating 1 field with a batch_length of 100). The memory cost of a Remapper is order (B, H, V), where B is the buffer length, H is the number of horizontal points, and V the number of vertical points. For H = 180x90 and V = 50, this means that each buffer costs 51_840_000 bytes (50 MB) for double precision on the root process + 50 MB / N_tasks on each task + base cost that is independent of B. --- .buildkite/pipeline.yml | 24 +- NEWS.md | 3 + src/Remapping/distributed_remapping.jl | 626 ++++++++++++++++-------- test/Remapping/distributed_remapping.jl | 395 ++++++++++++++- 4 files changed, 825 insertions(+), 223 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 2683663ddb..45509930cd 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -243,8 +243,14 @@ steps: agents: slurm_ntasks: 4 - - label: "Unit: distributed remapping" - key: distributed_remapping + - label: "Unit: distributed remapping (1 process)" + key: distributed_remapping_1proc + command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" + env: + CLIMACOMMS_DEVICE: "CPU" + + - label: "Unit: distributed remapping (2 processes)" + key: distributed_remapping_2procs command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: CLIMACOMMS_CONTEXT: "MPI" @@ -252,21 +258,23 @@ steps: agents: slurm_ntasks: 2 - - label: "Unit: distributed remapping (1 process)" - key: distributed_remapping_1proc + - label: "Unit: distributed remapping with CUDA (1 process)" + key: distributed_remapping_gpu_1proc command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: - CLIMACOMMS_DEVICE: "CPU" + CLIMACOMMS_DEVICE: "CUDA" + agents: + slurm_gpus: 1 - - label: "Unit: distributed remapping with CUDA" - key: distributed_remapping_gpu + - label: "Unit: distributed remapping with CUDA (2 processes)" + key: distributed_remapping_gpu_2procs command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl" env: CLIMACOMMS_CONTEXT: "MPI" CLIMACOMMS_DEVICE: "CUDA" agents: slurm_ntasks: 2 - slurm_gpus: 1 + slurm_gpus_per_task: 1 - label: "Unit: distributed gather" key: unit_distributed_gather4 diff --git a/NEWS.md b/NEWS.md index 81444b5399..d3e4413d00 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,9 @@ ClimaCore.jl Release Notes main ------- +- ![][badge-πŸ€–precisionΞ”] `Remapper`s can now process multiple `Field`s at the same time if created with some `buffer_lenght > 1`. + PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)) Machine-precision differences are expected. + v0.13.4 ------- diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 72178f8519..874502285d 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -78,6 +78,11 @@ # process-local interpolate points in the correct shape with respect to the global # interpolation (and where to collect results) # +# To process multiple Fields at the same time, some of the scratch spaces gain an extra +# dimension (buffer_length). With this extra dimension, we can batch the work and process up +# to buffer_length fields at the same time. This reduces the number of kernel launches and +# MPI calls. +# # For GPU runs, we store all the vectors as CuArrays on the GPU. """ @@ -124,6 +129,7 @@ struct Remapper{ T8 <: AbstractArray, T9 <: AbstractArray, T10 <: AbstractArray, + T11 <: Union{Tuple{Colon}, Tuple{Colon, Colon}, Tuple{Colon, Colon, Colon}}, } comms_ctx::CC @@ -170,25 +176,32 @@ struct Remapper{ # Scratch space where we save the process-local interpolated values. We keep overwriting # this to avoid extra allocations. This is a linear array with the same length as - # local_horiz_indices. + # local_horiz_indices with an extra dimension of length buffer_length added. _local_interpolated_values::T8 # Scratch space where we save the process-local field value. We keep overwriting this to # avoid extra allocations. Ideally, we wouldn't need this and we would use views for - # everything. This has dimensions (Nq) or (Nq, Nq) depending if the horizontal space is - # 1D or 2D. + # everything. This has dimensions (buffer_length, Nq) or (buffer_length, Nq, Nq) + # depending if the horizontal space is 1D or 2D. _field_values::T9 # Storage area where the interpolated values are saved. This is meaningful only for the - # root process and gets filled by a interpolate call. This has dimensions (H, V), where - # H is the size of target_hcoords and V of target_zcoords. In other words, this is the - # expected output array. + # root process and gets filled by a interpolate call. This has dimensions + # (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords. + # In other words, this is the expected output array. _interpolated_values::T10 + + # Maximum number of Fields that can be interpolated at any given time + buffer_length::Int + + # A tuple of Colons (1, 2, or 3), used to more easily get views into arrays with unknown + # dimension (1-3D) + colons::T11 end """ - Remapper(space, target_hcoords, target_zcoords) - Remapper(space, target_hcoords) + Remapper(space, target_hcoords, target_zcoords; buffer_length = 1) + Remapper(space, target_hcoords; buffer_length = 1) Return a `Remapper` responsible for interpolating any `Field` defined on the given `space` to the Cartesian product of `target_hcoords` with `target_zcoords`. @@ -199,11 +212,21 @@ The `Remapper` is designed to not be tied to any particular `Field`. You can use `Remapper` for any `Field` as long as they are all defined on the same `topology`. `Remapper` is the main argument to the `interpolate` function. + +Keyword arguments +================= + +`buffer_length` is size of the internal buffer in the Remapper to store intermediate values +for interpolation. Effectively, this controls how many fields can be remapped simultaneously +in `interpolate`. When more fields than `buffer_length` are passed, the remapper will batch +the work in sizes of `buffer_length`. + """ function Remapper( space::Spaces.AbstractSpace, target_hcoords::AbstractArray, - target_zcoords::Union{AbstractArray, Nothing}, + target_zcoords::Union{AbstractArray, Nothing}; + buffer_length::Int = 1, ) comms_ctx = ClimaComms.context(space) @@ -277,7 +300,7 @@ function Remapper( local_horiz_indices = ArrayType(local_horiz_indices) field_values_size = ntuple(_ -> Nq, num_hdims) - field_values = ArrayType(zeros(FT, field_values_size)) + field_values = ArrayType(zeros(FT, field_values_size...)) # We represent interpolation onto an horizontal slab as an empty list of zcoords if isnothing(target_zcoords) || isempty(target_zcoords) @@ -285,9 +308,11 @@ function Remapper( vert_interpolation_weights = nothing vert_bounding_indices = nothing local_interpolated_values = - ArrayType(zeros(FT, size(local_horiz_indices))) - interpolated_values = - ArrayType(zeros(FT, size(local_target_hcoords_bitmask))) + ArrayType(zeros(FT, (buffer_length, size(local_horiz_indices)...))) + interpolated_values = ArrayType( + zeros(FT, (buffer_length, size(local_target_hcoords_bitmask)...)), + ) + num_dims = num_hdims else vert_interpolation_weights = ArrayType(vertical_interpolation_weights(space, target_zcoords)) @@ -297,16 +322,33 @@ function Remapper( # We have to add one extra dimension with respect to the bitmask/local_horiz_indices # because we are going to store the values for the columns local_interpolated_values = ArrayType( - zeros(FT, (size(local_horiz_indices)..., length(target_zcoords))), + zeros( + FT, + ( + buffer_length, + size(local_horiz_indices)..., + length(target_zcoords), + ), + ), ) interpolated_values = ArrayType( zeros( FT, - (size(local_target_hcoords_bitmask)..., length(target_zcoords)), + ( + buffer_length, + size(local_target_hcoords_bitmask)..., + length(target_zcoords), + ), ), ) + num_dims = num_hdims + 1 end + # We don't know how many dimensions an array might have, so we define a colons object + # that we can use to index with array[colons...] + + colons = ntuple(_ -> Colon(), num_dims) + return Remapper( comms_ctx, space, @@ -320,22 +362,27 @@ function Remapper( local_interpolated_values, field_values, interpolated_values, + buffer_length, + colons, ) end -Remapper(space::Spaces.AbstractSpace, target_hcoords::AbstractArray) = - Remapper(space, target_hcoords, nothing) +Remapper( + space::Spaces.AbstractSpace, + target_hcoords::AbstractArray; + buffer_length::Int = 1, +) = Remapper(space, target_hcoords, nothing; buffer_length) """ _set_interpolated_values!(remapper, field) -Change the local state of `remapper` by performing interpolation of `Fields` on the vertical +Change the local state of `remapper` by performing interpolation of `fields` on the vertical and horizontal points. """ -function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) +function _set_interpolated_values!(remapper::Remapper, fields) _set_interpolated_values!( remapper._local_interpolated_values, - field, + fields, remapper._field_values, remapper.local_horiz_indices, remapper.local_horiz_interpolation_weights, @@ -344,52 +391,55 @@ function _set_interpolated_values!(remapper::Remapper, field::Fields.Field) ) end +# CPU, 3D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, (I1, I2)::NTuple{2}, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) - - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex - for j in 1:Nq, i in 1:Nq - scratch_field_values[i, j] = ( - A * field_values[i, j, nothing, v_lo, h] + - B * field_values[i, j, nothing, v_hi, h] - ) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for j in 1:Nq, i in 1:Nq + scratch_field_values[i, j] = ( + A * field_values[i, j, nothing, v_lo, h] + + B * field_values[i, j, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h end - prev_vindex, prev_lidx = vindex, h - end - tmp = zero(FT) + tmp = zero(FT) - for j in 1:Nq, i in 1:Nq - tmp += - I1[out_index, i] * - I2[out_index, j] * - scratch_field_values[i, j] + for j in 1:Nq, i in 1:Nq + tmp += + I1[out_index, i] * + I2[out_index, j] * + scratch_field_values[i, j] + end + out[field_index, out_index, vindex] = tmp end - out[out_index, vindex] = tmp end end end +# GPU, 3D case function set_interpolated_values_kernel!( out::AbstractArray, (I1, I2)::NTuple{2}, @@ -398,32 +448,46 @@ function set_interpolated_values_kernel!( vert_bounding_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_vert = length(vert_bounding_indices) + num_fields = length(field_values) - hindex = blockIdx().x - vindex = threadIdx().x - index = vindex + (hindex - 1) * blockDim().x - index > length(out) && return nothing + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - h = local_horiz_indices[hindex] - v_lo, v_hi = vert_bounding_indices[vindex] - A, B = vert_interpolation_weights[vindex] + totalThreadsX = gridDim().x * blockDim().x + totalThreadsY = gridDim().y * blockDim().y + totalThreadsZ = gridDim().z * blockDim().z _, Nq = size(I1) - out[hindex, vindex] = 0 - for j in 1:Nq, i in 1:Nq - out[hindex, vindex] += - I1[hindex, i] * - I2[hindex, j] * - ( - A * field_values[i, j, nothing, v_lo, h] + - B * field_values[i, j, nothing, v_hi, h] - ) + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for j in vindex:totalThreadsY:num_vert + v_lo, v_hi = vert_bounding_indices[j] + A, B = vert_interpolation_weights[j] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields + out[k, i, j] = 0 + for t in 1:Nq, s in 1:Nq + out[k, i, j] += + I1[i, t] * + I2[i, s] * + ( + A * field_values[k][t, s, nothing, v_lo, h] + + B * field_values[k][t, s, nothing, v_hi, h] + ) + end + end + end + end end - return nothing end +# GPU, 2D case function set_interpolated_values_kernel!( out::AbstractArray, (I,)::NTuple{1}, @@ -432,87 +496,105 @@ function set_interpolated_values_kernel!( vert_bounding_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_vert = length(vert_bounding_indices) + num_fields = length(field_values) - hindex = blockIdx().x - vindex = threadIdx().x - index = vindex + (hindex - 1) * blockDim().x - index > length(out) && return nothing + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - h = local_horiz_indices[hindex] - v_lo, v_hi = vert_bounding_indices[vindex] - A, B = vert_interpolation_weights[vindex] + totalThreadsX = gridDim().x * blockDim().x + totalThreadsY = gridDim().y * blockDim().y + totalThreadsZ = gridDim().z * blockDim().z - _, Nq = size(I) + _, Nq = size(I1) - out[hindex, vindex] = 0 - for i in 1:Nq - out[hindex, vindex] += - I[hindex, i] * ( - A * field_values[i, nothing, nothing, v_lo, h] + - B * field_values[i, nothing, nothing, v_hi, h] - ) + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for j in vindex:totalThreadsY:num_vert + v_lo, v_hi = vert_bounding_indices[j] + A, B = vert_interpolation_weights[j] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields + out[k, i, j] = 0 + for t in 1:Nq + out[k, i, j] += + I1[i, t] * + I2[i, s] * + ( + A * + field_values[k][t, nothing, nothing, v_lo, h] + + B * + field_values[k][t, nothing, nothing, v_hi, h] + ) + end + end + end + end end - return nothing end +# CPU, 2D case function set_interpolated_values_cpu_kernel!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, (I,)::NTuple{1}, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + + # Reading values from field_values is expensive, so we try to limit the number of reads. We can do + # this because multiple target points might be all contained in the same element. + prev_vindex, prev_lidx = -1, -1 + @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) + (v_lo, v_hi) = vert_bounding_indices[vindex] + + for (out_index, h) in enumerate(local_horiz_indices) + # If we are no longer in the same element, read the field values again + if prev_lidx != h || prev_vindex != vindex + for i in 1:Nq + scratch_field_values[i] = ( + A * field_values[i, nothing, nothing, v_lo, h] + + B * field_values[i, nothing, nothing, v_hi, h] + ) + end + prev_vindex, prev_lidx = vindex, h + end - # Reading values from field_values is expensive, so we try to limit the number of reads. We can do - # this because multiple target points might be all contained in the same element. - prev_vindex, prev_lidx = -1, -1 - @inbounds for (vindex, (A, B)) in enumerate(vert_interpolation_weights) - (v_lo, v_hi) = vert_bounding_indices[vindex] + tmp = zero(FT) - for (out_index, h) in enumerate(local_horiz_indices) - # If we are no longer in the same element, read the field values again - if prev_lidx != h || prev_vindex != vindex for i in 1:Nq - scratch_field_values[i] = ( - A * field_values[i, nothing, nothing, v_lo, h] + - B * field_values[i, nothing, nothing, v_hi, h] - ) + tmp += I[out_index, i] * scratch_field_values[i] end - prev_vindex, prev_lidx = vindex, h + out[field_index, out_index, vindex] = tmp end - - tmp = zero(FT) - - for i in 1:Nq - tmp += I[out_index, i] * scratch_field_values[i] - end - out[out_index, vindex] = tmp end end end function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, scratch_field_values, local_horiz_indices, interpolation_matrix, vert_interpolation_weights::AbstractArray, vert_bounding_indices::AbstractArray, ) - - field_values = Fields.field_values(field) - - if ClimaComms.device(field) isa ClimaComms.CUDADevice + if ClimaComms.device(first(fields)) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) nblocks, _ = size(interpolation_matrix[1]) nthreads = length(vert_interpolation_weights) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -526,20 +608,20 @@ function _set_interpolated_values!( else set_interpolated_values_cpu_kernel!( out, - field, + fields, interpolation_matrix, local_horiz_indices, vert_interpolation_weights, vert_bounding_indices, scratch_field_values, - field_values, ) end end +# Horizontal function _set_interpolated_values!( out::AbstractArray, - field::Fields.Field, + fields::AbstractArray{<:Fields.Field}, _scratch_field_values, local_horiz_indices, local_horiz_interpolation_weights, @@ -547,16 +629,17 @@ function _set_interpolated_values!( ::Nothing, ) - space = axes(field) + space = axes(first(fields)) FT = Spaces.undertype(space) quad = Spaces.quadrature_style(space) Nq = Quadratures.degrees_of_freedom(quad) - field_values = Fields.field_values(field) hdims = length(local_horiz_interpolation_weights) hdims in (1, 2) || error("Cannot handle $hdims horizontal dimensions") if ClimaComms.device(space) isa ClimaComms.CUDADevice + # FIXME: Avoid allocation of tuple + field_values = tuple(map(f -> Fields.field_values(f), fields)...) nitems = length(out) nthreads, nblocks = Topologies._configure_threadblock(nitems) @cuda always_inline = true threads = (nthreads) blocks = (nblocks) set_interpolated_values_kernel!( @@ -566,20 +649,23 @@ function _set_interpolated_values!( field_values, ) else - for (out_index, h) in enumerate(local_horiz_indices) - out[out_index] = zero(FT) - if hdims == 2 - for j in 1:Nq, i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - local_horiz_interpolation_weights[2][out_index, j] * - field_values[i, j, nothing, nothing, h] - end - elseif hdims == 1 - for i in 1:Nq - out[out_index] += - local_horiz_interpolation_weights[1][out_index, i] * - field_values[i, nothing, nothing, nothing, h] + for (field_index, field) in enumerate(fields) + field_values = Fields.field_values(field) + for (out_index, h) in enumerate(local_horiz_indices) + out[field_index, out_index] = zero(FT) + if hdims == 2 + for j in 1:Nq, i in 1:Nq + out[field_index, out_index] += + local_horiz_interpolation_weights[1][out_index, i] * + local_horiz_interpolation_weights[2][out_index, j] * + field_values[i, j, nothing, nothing, h] + end + elseif hdims == 1 + for i in 1:Nq + out[field_index, out_index] += + local_horiz_interpolation_weights[1][out_index, i] * + field_values[i, nothing, nothing, nothing, h] + end end end end @@ -592,21 +678,32 @@ function set_interpolated_values_kernel!( local_horiz_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_fields = length(field_values) - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - index > length(out) && return nothing + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z + + totalThreadsX = gridDim().x * blockDim().x + totalThreadsZ = gridDim().z * blockDim().z - h = local_horiz_indices[index] _, Nq = size(I1) - out[index] = 0 - for j in 1:Nq, i in 1:Nq - out[index] += - I1[index, i] * - I2[index, j] * - field_values[i, j, nothing, nothing, h] + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && k ≀ num_fields + out[k, i] = 0 + for t in 1:Nq, s in 1:Nq + out[k, i] += + I1[i, t] * + I2[i, s] * + field_values[k][t, s, nothing, nothing, h] + end + end + end end - return nothing end @@ -616,47 +713,56 @@ function set_interpolated_values_kernel!( local_horiz_indices, field_values, ) + # TODO: Check the memory access pattern. This was not optimized and likely inefficient! + num_horiz = length(local_horiz_indices) + num_fields = length(field_values) + + hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x + findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z - index = threadIdx().x + (blockIdx().x - 1) * blockDim().x - index > length(out) && return nothing + totalThreadsX = gridDim().x * blockDim().x + totalThreadsZ = gridDim().z * blockDim().z - h = local_horiz_indices[index] _, Nq = size(I) - out[index] = 0 - for i in 1:Nq - out[index] += - I[index, i] * field_values[i, nothing, nothing, nothing, h] + for i in hindex:totalThreadsX:num_horiz + h = local_horiz_indices[i] + for k in findex:totalThreadsZ:num_fields + if i ≀ num_horiz && k ≀ num_fields + out[k, i] = 0 + for t in 1:Nq, s in 1:Nq + out[k, i] += + I[i, i] * + field_values[k][t, nothing, nothing, nothing, h] + end + end + end end - return nothing end - """ - _apply_mpi_bitmask!(remapper::Remapper) + _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) Change to local (private) state of the `remapper` by applying the MPI bitmask and reconstructing the correct shape for the interpolated values. Internally, `remapper` performs interpolation on a flat list of points, this function moves points around according to MPI-ownership and the expected output shape. + +`num_fields` is the number of fields that have been processed and have to be moved in the +`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved. """ -function _apply_mpi_bitmask!(remapper::Remapper) - if isnothing(remapper.target_zcoords) - # _interpolated_values[remapper.local_target_hcoords_bitmask] returns a view on - # space we want to write on - remapper._interpolated_values[remapper.local_target_hcoords_bitmask] .= - remapper._local_interpolated_values - else - # interpolated_values is an array of arrays properly ordered according to the bitmask - - # _interpolated_values[remapper.local_target_hcoords_bitmask, :] returns a - # view on space we want to write on - remapper._interpolated_values[ - remapper.local_target_hcoords_bitmask, - :, - ] .= remapper._local_interpolated_values - end +function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) + view( + remapper._interpolated_values, + 1:num_fields, + remapper.local_target_hcoords_bitmask, + :, + ) .= view( + remapper._local_interpolated_values, + 1:num_fields, + remapper.colons..., + ) end """ @@ -670,21 +776,60 @@ function _reset_interpolated_values!(remapper::Remapper) end """ - _collect_and_return_interpolated_values!(remapper::Remapper) + _collect_and_return_interpolated_values!(remapper::Remapper, + num_fields::Int) Perform an MPI call to aggregate the interpolated points from all the MPI processes and save the result in the local state of the `remapper`. Only the root process will return the interpolated data. `_collect_and_return_interpolated_values!` is type-unstable and allocates new return arrays. + +`num_fields` is the number of fields that have been interpolated in this batch. """ -function _collect_and_return_interpolated_values!(remapper::Remapper) - ClimaComms.reduce!(remapper.comms_ctx, remapper._interpolated_values, +) +function _collect_and_return_interpolated_values!( + remapper::Remapper, + num_fields::Int, +) + output_array = ClimaComms.reduce( + remapper.comms_ctx, + remapper._interpolated_values[1:num_fields, remapper.colons...], + +, + ) + + maybe_copy_to_cpu = + ClimaComms.device(remapper.comms_ctx) isa ClimaComms.CUDADevice ? + Array : identity + return ClimaComms.iamroot(remapper.comms_ctx) ? - Array(remapper._interpolated_values) : nothing + maybe_copy_to_cpu(output_array) : nothing end -function _collect_interpolated_values!(dest, remapper::Remapper) +function _collect_interpolated_values!( + dest, + remapper::Remapper, + index_field_begin::Int, + index_field_end::Int, +) + + num_fields = 1 + index_field_end - index_field_begin + only_one_field = num_fields == 1 + + if only_one_field + ClimaComms.reduce!( + remapper.comms_ctx, + remapper._interpolated_values[1, remapper.colons...], + dest, + +, + ) + return nothing + end + + # CUDA.jl does not support views very well at the moment. We can only work with + # num_fields = buffer_length + num_fields == remapper.buffer_length || + error("Operation not currently supported") + # MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular, # if the destination is on the CPU, but the source is on the GPU, the values # are automatically moved. @@ -694,14 +839,32 @@ function _collect_interpolated_values!(dest, remapper::Remapper) dest, +, ) + return nothing end """ - interpolate(remapper::Remapper, field) - interpolate!(dest, remapper::Remapper, field) + batched_ranges(num_fields, buffer_length) -Interpolate the given `field` as prescribed by `remapper`. +Partition the indices from 1 to num_fields in such a way that no range is larger than +buffer_length. +""" +function batched_ranges(num_fields, buffer_length) + return [ + (i * buffer_length + 1):(min((i + 1) * buffer_length, num_fields)) for + i in 0:(div((num_fields - 1), buffer_length)) + ] +end + +""" + interpolate(remapper::Remapper, fields) + interpolate!(dest, remapper::Remapper, fields) + +Interpolate the given `field`(s) as prescribed by `remapper`. + +The optimal number of fields passed is the `buffer_length` of the `remapper`. If +more fields are passed, the `remapper` will batch work with size up to its +`buffer_length`. This call mutates the internal (private) state of the `remapper`. @@ -734,53 +897,108 @@ remapper = Remapper(space, hcoords, zcoords) int1 = interpolate(remapper, field1) int2 = interpolate(remapper, field2) + +# Or +int12 = interpolate(remapper, [field1, field2]) +# With int1 = int12[1, :, :, :] ``` """ -function interpolate(remapper::Remapper, field::T) where {T <: Fields.Field} +function interpolate(remapper::Remapper, fields) - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them up to - # obtain the final answer. Only the root will contain something useful. This also moves - # the data off the GPU - return _collect_and_return_interpolated_values!(remapper) + for field in fields + axes(field) == remapper.space || + error("Field is defined on a different space than remapper") + end + + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + # Partition the indices in such a way that nothing is larger than + # buffer_length + index_ranges = batched_ranges(length(fields), remapper.buffer_length) + + interpolated_values = mapreduce(vcat, index_ranges) do range + num_fields = length(range) + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + return _collect_and_return_interpolated_values!(remapper, num_fields) + end + + # Non-root processes + isnothing(interpolated_values) && return nothing + + return only_one_field ? interpolated_values[begin, remapper.colons...] : + interpolated_values end +# dest has to be allowed to be nothing because interpolation happens only on the root +# process function interpolate!( dest::Union{Nothing, <:AbstractArray}, remapper::Remapper, - field::T, -) where {T <: Fields.Field} - - axes(field) == remapper.space || - error("Field is defined on a different space than remapper") + fields, +) + only_one_field = fields isa Fields.Field + if only_one_field + fields = [fields] + end if !isnothing(dest) # !isnothing(dest) means that this is the root process, in this case, the size have - # to match - size(dest) == size(remapper._interpolated_values) || error( + # to match (ignoring the buffer_length) + dest_size = only_one_field ? size(dest) : size(dest)[2:end] + + dest_size == size(remapper._interpolated_values)[2:end] || error( "Destination array is not compatible with remapper (size mismatch)", ) end + index_field_begin, index_field_end = + 1, min(length(fields), remapper.buffer_length) + + while true + num_fields = 1 + index_field_end - index_field_begin + + # Reset interpolated_values. This is needed because we collect distributed results + # with a + reduction. + _reset_interpolated_values!(remapper) + # Perform the interpolations (horizontal and vertical) + _set_interpolated_values!( + remapper, + view(fields, index_field_begin:index_field_end), + ) + # Reshape the output so that it is a nice grid. + _apply_mpi_bitmask!(remapper, num_fields) + # Finally, we have to send all the _interpolated_values to root and sum them up to + # obtain the final answer. Only the root will contain something useful. This also + # moves the data off the GPU + _collect_interpolated_values!( + dest, + remapper, + index_field_begin, + index_field_end, + ) - # Reset interpolated_values. This is needed because we collect distributed results with - # a + reduction. - _reset_interpolated_values!(remapper) - # Perform the interpolations (horizontal and vertical) - _set_interpolated_values!(remapper, field) - # Reshape the output so that it is a nice grid. - _apply_mpi_bitmask!(remapper) - # Finally, we have to send all the _interpolated_values to root and sum them - # up to obtain the final answer. This also moves the data off the GPU. The - # output is written to the given destination - _collect_interpolated_values!(dest, remapper) + index_field_end != length(fields) || break + index_field_begin = index_field_begin + remapper.buffer_length + index_field_end = + min(length(fields), index_field_end + remapper.buffer_length) + end return nothing end diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 1e7bd7d39b..0d842d50dc 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -25,7 +25,18 @@ atexit() do global_logger(prev_logger) end -if !(device isa ClimaComms.CUDADevice) +@testset "Utils" begin + # batched_ranges(num_fields, buffer_length) + @test Remapping.batched_ranges(1, 1) == [1:1] + @test Remapping.batched_ranges(1, 2) == [1:1] + @test Remapping.batched_ranges(2, 2) == [1:2] + @test Remapping.batched_ranges(3, 2) == [1:2, 3:3] +end + +on_gpu = device isa ClimaComms.CUDADevice +broken = true + +if !on_gpu @testset "2D extruded" begin vertdomain = Domains.IntervalDomain( Geometry.ZPoint(0.0), @@ -63,7 +74,12 @@ if !(device isa ClimaComms.CUDADevice) hcoords = [Geometry.XPoint(x) for x in xpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = Remapping.Remapper( + hv_center_space, + hcoords, + zcoords, + buffer_length = 2, + ) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -79,6 +95,54 @@ if !(device isa ClimaComms.CUDADevice) @test interp_z[:, end] β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts] end + + # Remapping two fields + interp_xx = Remapping.interpolate(remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xx[1, :, :, :] + @test interp_x β‰ˆ interp_xx[2, :, :, :] + end + + # Remapping three fields (more than the buffer length) + interp_xxx = + Remapping.interpolate(remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xxx[1, :, :, :] + @test interp_x β‰ˆ interp_xxx[2, :, :, :] + @test interp_x β‰ˆ interp_xxx[3, :, :, :] + end + + # Remapping in-place one field + if !broken + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_x β‰ˆ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + remapper, + [coords.x, coords.x, coords.x], + ) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_x β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[3, :, :, :] + end + end end end @@ -122,7 +186,8 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -143,9 +208,56 @@ end @test interp_z[:, :, end] β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xy[1, :, :, :] + @test interp_y β‰ˆ interp_xy[2, :, :, :] + end + # Remapping three fields (more than the buffer length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xyx[1, :, :, :] + @test interp_y β‰ˆ interp_xyx[2, :, :, :] + @test interp_x β‰ˆ interp_xyx[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper + dest = zeros(21, 21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_y β‰ˆ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_y β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[3, :, :, :] + end + end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -159,6 +271,57 @@ end if ClimaComms.iamroot(context) @test interp_y β‰ˆ [y for x in xpts, y in ypts] end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_xy[1, :, :] β‰ˆ interp_x + @test interp_xy[2, :, :] β‰ˆ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[1, :, :] β‰ˆ interp_x + @test interp_xyx[2, :, :] β‰ˆ interp_y + @test interp_xyx[3, :, :] β‰ˆ interp_x + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + if !broken + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :] + @test interp_y β‰ˆ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [coords.x, coords.y, coords.x], + ) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :] + @test interp_y β‰ˆ dest[2, :, :] + @test interp_x β‰ˆ dest[3, :, :] + end + end end @@ -203,7 +366,8 @@ end hcoords = [Geometry.XYPoint(x, y) for x in xpts, y in ypts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) @@ -225,9 +389,55 @@ end [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end + # Remapping two fields + interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xy[1, :, :, :] + @test interp_y β‰ˆ interp_xy[2, :, :, :] + end + # Remapping three fields (more than the buffer length) + interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ interp_xyx[1, :, :, :] + @test interp_y β‰ˆ interp_xyx[2, :, :, :] + @test interp_x β‰ˆ interp_xyx[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper + dest = zeros(21, 21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_y β‰ˆ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :, :] + @test interp_y β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[3, :, :, :] + end + end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -241,6 +451,57 @@ end if ClimaComms.iamroot(context) @test interp_y β‰ˆ [y for x in xpts, y in ypts] end + + # Two fields + interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_xy[1, :, :] β‰ˆ interp_x + @test interp_xy[2, :, :] β‰ˆ interp_y + end + + # Three fields + interp_xyx = + Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_xyx[1, :, :] β‰ˆ interp_x + @test interp_xyx[2, :, :] β‰ˆ interp_y + @test interp_xyx[3, :, :] β‰ˆ interp_x + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + if !broken + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest + end + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :] + @test interp_y β‰ˆ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [coords.x, coords.y, coords.x], + ) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[1, :, :] + @test interp_y β‰ˆ dest[2, :, :] + @test interp_x β‰ˆ dest[3, :, :] + end + end end @testset "3D sphere" begin @@ -274,7 +535,8 @@ end [Geometry.LatLongPoint(lat, long) for long in longpts, lat in latpts] zcoords = [Geometry.ZPoint(z) for z in zpts] - remapper = Remapping.Remapper(hv_center_space, hcoords, zcoords) + remapper = + Remapping.Remapper(hv_center_space, hcoords, zcoords, buffer_length = 2) coords = Fields.coordinate_field(hv_center_space) @@ -301,16 +563,67 @@ end [1000.0 * (29 / 30 + 30 / 30) / 2 for x in longpts, y in latpts] end - # Test interpolation in place + # Remapping two fields + interp_long_lat = + Remapping.interpolate(remapper, [sind.(coords.long), sind.(coords.lat)]) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ interp_long_lat[1, :, :, :] + @test interp_sin_lat β‰ˆ interp_long_lat[2, :, :, :] + end + # Remapping three fields (more than the buffer length) + interp_long_lat_long = Remapping.interpolate( + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ interp_long_lat_long[1, :, :, :] + @test interp_sin_lat β‰ˆ interp_long_lat_long[2, :, :, :] + @test interp_sin_long β‰ˆ interp_long_lat_long[3, :, :, :] + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + remapper_1field = + on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : + remapper dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper, sind.(coords.lat)) + Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) if ClimaComms.iamroot(context) - @test dest β‰ˆ [sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01 + @test interp_sin_long β‰ˆ dest + end + + # Two fields + dest = zeros(2, 21, 21, 21) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[1, :, :, :] + @test interp_sin_lat β‰ˆ dest[2, :, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21, 21) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[1, :, :, :] + @test interp_sin_lat β‰ˆ dest[2, :, :, :] + @test interp_sin_long β‰ˆ dest[3, :, :, :] + end end # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) - horiz_remapper = Remapping.Remapper(horiz_space, hcoords, nothing) + horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) coords = Fields.coordinate_field(horiz_space) @@ -324,4 +637,64 @@ end if ClimaComms.iamroot(context) @test interp_sin_lat β‰ˆ [sind(y) for x in longpts, y in latpts] rtol = 0.01 end + + # Two fields + interp_sin_long_lat = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat[1, :, :] β‰ˆ interp_sin_long + @test interp_sin_long_lat[2, :, :] β‰ˆ interp_sin_lat + end + + # Three fields + interp_sin_long_lat_long = Remapping.interpolate( + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long_lat_long[1, :, :] β‰ˆ interp_sin_long + @test interp_sin_long_lat_long[2, :, :] β‰ˆ interp_sin_lat + @test interp_sin_long_lat_long[3, :, :] β‰ˆ interp_sin_long + end + + # Remapping in-place one field + # + # We have to change remapper for GPU to make sure it works for when have have only one + # field + if !broken + dest = zeros(21, 21) + Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest + end + end + + # Two fields + dest = zeros(2, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[1, :, :] + @test interp_sin_lat β‰ˆ dest[2, :, :] + end + + # Three fields (more than buffer length) + if !broken + dest = zeros(3, 21, 21) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[1, :, :] + @test interp_sin_lat β‰ˆ dest[2, :, :] + @test interp_sin_long β‰ˆ dest[3, :, :] + end + end end From a7eb55e9405ef24dedea735fdbd921f0ee55ab82 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Thu, 18 Apr 2024 13:36:38 -0700 Subject: [PATCH 2/3] Move field index to last --- src/Remapping/distributed_remapping.jl | 85 ++++++------ test/Remapping/distributed_remapping.jl | 168 ++++++++++++------------ 2 files changed, 130 insertions(+), 123 deletions(-) diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index 874502285d..ce05bdf6be 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -181,13 +181,13 @@ struct Remapper{ # Scratch space where we save the process-local field value. We keep overwriting this to # avoid extra allocations. Ideally, we wouldn't need this and we would use views for - # everything. This has dimensions (buffer_length, Nq) or (buffer_length, Nq, Nq) + # everything. This has dimensions (Nq, ) or (Nq, Nq, ) # depending if the horizontal space is 1D or 2D. _field_values::T9 # Storage area where the interpolated values are saved. This is meaningful only for the # root process and gets filled by a interpolate call. This has dimensions - # (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords. + # (H, V, buffer_length), where H is the size of target_hcoords and V of target_zcoords. # In other words, this is the expected output array. _interpolated_values::T10 @@ -308,9 +308,9 @@ function Remapper( vert_interpolation_weights = nothing vert_bounding_indices = nothing local_interpolated_values = - ArrayType(zeros(FT, (buffer_length, size(local_horiz_indices)...))) + ArrayType(zeros(FT, (size(local_horiz_indices)..., buffer_length))) interpolated_values = ArrayType( - zeros(FT, (buffer_length, size(local_target_hcoords_bitmask)...)), + zeros(FT, (size(local_target_hcoords_bitmask)..., buffer_length)), ) num_dims = num_hdims else @@ -325,9 +325,9 @@ function Remapper( zeros( FT, ( - buffer_length, size(local_horiz_indices)..., length(target_zcoords), + buffer_length, ), ), ) @@ -335,9 +335,9 @@ function Remapper( zeros( FT, ( - buffer_length, size(local_target_hcoords_bitmask)..., length(target_zcoords), + buffer_length, ), ), ) @@ -433,7 +433,7 @@ function set_interpolated_values_cpu_kernel!( I2[out_index, j] * scratch_field_values[i, j] end - out[field_index, out_index, vindex] = tmp + out[out_index, vindex, field_index] = tmp end end end @@ -470,9 +470,9 @@ function set_interpolated_values_kernel!( A, B = vert_interpolation_weights[j] for k in findex:totalThreadsZ:num_fields if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields - out[k, i, j] = 0 + out[i, j, k] = 0 for t in 1:Nq, s in 1:Nq - out[k, i, j] += + out[i, j, k] += I1[i, t] * I2[i, s] * ( @@ -509,7 +509,7 @@ function set_interpolated_values_kernel!( totalThreadsY = gridDim().y * blockDim().y totalThreadsZ = gridDim().z * blockDim().z - _, Nq = size(I1) + _, Nq = size(I) for i in hindex:totalThreadsX:num_horiz h = local_horiz_indices[i] @@ -518,11 +518,11 @@ function set_interpolated_values_kernel!( A, B = vert_interpolation_weights[j] for k in findex:totalThreadsZ:num_fields if i ≀ num_horiz && j ≀ num_vert && k ≀ num_fields - out[k, i, j] = 0 + out[i, j, k] = 0 for t in 1:Nq - out[k, i, j] += - I1[i, t] * - I2[i, s] * + out[i, j, k] += + I[i, t] * + I[i, s] * ( A * field_values[k][t, nothing, nothing, v_lo, h] + @@ -577,7 +577,7 @@ function set_interpolated_values_cpu_kernel!( for i in 1:Nq tmp += I[out_index, i] * scratch_field_values[i] end - out[field_index, out_index, vindex] = tmp + out[out_index, vindex, field_index] = tmp end end end @@ -652,17 +652,17 @@ function _set_interpolated_values!( for (field_index, field) in enumerate(fields) field_values = Fields.field_values(field) for (out_index, h) in enumerate(local_horiz_indices) - out[field_index, out_index] = zero(FT) + out[out_index, field_index] = zero(FT) if hdims == 2 for j in 1:Nq, i in 1:Nq - out[field_index, out_index] += + out[out_index, field_index] += local_horiz_interpolation_weights[1][out_index, i] * local_horiz_interpolation_weights[2][out_index, j] * field_values[i, j, nothing, nothing, h] end elseif hdims == 1 for i in 1:Nq - out[field_index, out_index] += + out[out_index, field_index] += local_horiz_interpolation_weights[1][out_index, i] * field_values[i, nothing, nothing, nothing, h] end @@ -694,9 +694,9 @@ function set_interpolated_values_kernel!( h = local_horiz_indices[i] for k in findex:totalThreadsZ:num_fields if i ≀ num_horiz && k ≀ num_fields - out[k, i] = 0 + out[i, k] = 0 for t in 1:Nq, s in 1:Nq - out[k, i] += + out[i, k] += I1[i, t] * I2[i, s] * field_values[k][t, s, nothing, nothing, h] @@ -729,9 +729,9 @@ function set_interpolated_values_kernel!( h = local_horiz_indices[i] for k in findex:totalThreadsZ:num_fields if i ≀ num_horiz && k ≀ num_fields - out[k, i] = 0 + out[i, k] = 0 for t in 1:Nq, s in 1:Nq - out[k, i] += + out[i, k] += I[i, i] * field_values[k][t, nothing, nothing, nothing, h] end @@ -753,16 +753,20 @@ around according to MPI-ownership and the expected output shape. `interpolated_values`. We assume that it is always the first `num_fields` that have to be moved. """ function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int) - view( - remapper._interpolated_values, - 1:num_fields, - remapper.local_target_hcoords_bitmask, - :, - ) .= view( - remapper._local_interpolated_values, - 1:num_fields, - remapper.colons..., - ) + if isnothing(remapper.target_zcoords) + view( + remapper._interpolated_values, + remapper.local_target_hcoords_bitmask, + 1:num_fields, + ) .= view(remapper._local_interpolated_values, :, 1:num_fields) + else + view( + remapper._interpolated_values, + remapper.local_target_hcoords_bitmask, + :, + 1:num_fields, + ) .= view(remapper._local_interpolated_values, :, :, 1:num_fields) + end end """ @@ -793,7 +797,7 @@ function _collect_and_return_interpolated_values!( ) output_array = ClimaComms.reduce( remapper.comms_ctx, - remapper._interpolated_values[1:num_fields, remapper.colons...], + remapper._interpolated_values[remapper.colons..., 1:num_fields], +, ) @@ -818,7 +822,7 @@ function _collect_interpolated_values!( if only_one_field ClimaComms.reduce!( remapper.comms_ctx, - remapper._interpolated_values[1, remapper.colons...], + remapper._interpolated_values[remapper.colons..., begin], dest, +, ) @@ -922,7 +926,9 @@ function interpolate(remapper::Remapper, fields) # buffer_length index_ranges = batched_ranges(length(fields), remapper.buffer_length) - interpolated_values = mapreduce(vcat, index_ranges) do range + cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1) + + interpolated_values = mapreduce(cat_fn, index_ranges) do range num_fields = length(range) # Reset interpolated_values. This is needed because we collect distributed results @@ -938,13 +944,14 @@ function interpolate(remapper::Remapper, fields) # Finally, we have to send all the _interpolated_values to root and sum them up to # obtain the final answer. Only the root will contain something useful. This also # moves the data off the GPU - return _collect_and_return_interpolated_values!(remapper, num_fields) + ret = _collect_and_return_interpolated_values!(remapper, num_fields) + return ret end # Non-root processes isnothing(interpolated_values) && return nothing - return only_one_field ? interpolated_values[begin, remapper.colons...] : + return only_one_field ? interpolated_values[remapper.colons..., begin] : interpolated_values end @@ -963,9 +970,9 @@ function interpolate!( if !isnothing(dest) # !isnothing(dest) means that this is the root process, in this case, the size have # to match (ignoring the buffer_length) - dest_size = only_one_field ? size(dest) : size(dest)[2:end] + dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)] - dest_size == size(remapper._interpolated_values)[2:end] || error( + dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error( "Destination array is not compatible with remapper (size mismatch)", ) end diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index 0d842d50dc..d30ddd944e 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -99,17 +99,17 @@ if !on_gpu # Remapping two fields interp_xx = Remapping.interpolate(remapper, [coords.x, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xx[1, :, :, :] - @test interp_x β‰ˆ interp_xx[2, :, :, :] + @test interp_x β‰ˆ interp_xx[:, :, 1] + @test interp_x β‰ˆ interp_xx[:, :, 2] end # Remapping three fields (more than the buffer length) interp_xxx = Remapping.interpolate(remapper, [coords.x, coords.x, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xxx[1, :, :, :] - @test interp_x β‰ˆ interp_xxx[2, :, :, :] - @test interp_x β‰ˆ interp_xxx[3, :, :, :] + @test interp_x β‰ˆ interp_xxx[:, :, 1] + @test interp_x β‰ˆ interp_xxx[:, :, 2] + @test interp_x β‰ˆ interp_xxx[:, :, 3] end # Remapping in-place one field @@ -122,25 +122,25 @@ if !on_gpu end # Two fields - dest = zeros(2, 21, 21) + dest = zeros(21, 21, 2) Remapping.interpolate!(dest, remapper, [coords.x, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_x β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_x β‰ˆ dest[:, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21) + dest = zeros(21, 21, 3) Remapping.interpolate!( dest, remapper, [coords.x, coords.x, coords.x], ) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_x β‰ˆ dest[2, :, :, :] - @test interp_x β‰ˆ dest[3, :, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_x β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end end @@ -212,15 +212,15 @@ end # Remapping two fields interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xy[1, :, :, :] - @test interp_y β‰ˆ interp_xy[2, :, :, :] + @test interp_x β‰ˆ interp_xy[:, :, :, 1] + @test interp_y β‰ˆ interp_xy[:, :, :, 2] end # Remapping three fields (more than the buffer length) interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xyx[1, :, :, :] - @test interp_y β‰ˆ interp_xyx[2, :, :, :] - @test interp_x β‰ˆ interp_xyx[3, :, :, :] + @test interp_x β‰ˆ interp_xyx[:, :, :, 1] + @test interp_y β‰ˆ interp_xyx[:, :, :, 2] + @test interp_x β‰ˆ interp_xyx[:, :, :, 3] end # Remapping in-place one field @@ -237,21 +237,21 @@ end end # Two fields - dest = zeros(2, 21, 21, 21) + dest = zeros(21, 21, 21, 2) Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_y β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21, 21) + dest = zeros(21, 21, 21, 3) Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_y β‰ˆ dest[2, :, :, :] - @test interp_x β‰ˆ dest[3, :, :, :] + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] end end @@ -275,17 +275,17 @@ end # Two fields interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_xy[1, :, :] β‰ˆ interp_x - @test interp_xy[2, :, :] β‰ˆ interp_y + @test interp_xy[:, :, 1] β‰ˆ interp_x + @test interp_xy[:, :, 2] β‰ˆ interp_y end # Three fields interp_xyx = Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_xyx[1, :, :] β‰ˆ interp_x - @test interp_xyx[2, :, :] β‰ˆ interp_y - @test interp_xyx[3, :, :] β‰ˆ interp_x + @test interp_xyx[:, :, 1] β‰ˆ interp_x + @test interp_xyx[:, :, 2] β‰ˆ interp_y + @test interp_xyx[:, :, 3] β‰ˆ interp_x end # Remapping in-place one field @@ -301,25 +301,25 @@ end end # Two fields - dest = zeros(2, 21, 21) + dest = zeros(21, 21, 2) Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :] - @test interp_y β‰ˆ dest[2, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21) + dest = zeros(21, 21, 3) Remapping.interpolate!( dest, horiz_remapper, [coords.x, coords.y, coords.x], ) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :] - @test interp_y β‰ˆ dest[2, :, :] - @test interp_x β‰ˆ dest[3, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end end @@ -392,15 +392,15 @@ end # Remapping two fields interp_xy = Remapping.interpolate(remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xy[1, :, :, :] - @test interp_y β‰ˆ interp_xy[2, :, :, :] + @test interp_x β‰ˆ interp_xy[:, :, :, 1] + @test interp_y β‰ˆ interp_xy[:, :, :, 2] end # Remapping three fields (more than the buffer length) interp_xyx = Remapping.interpolate(remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ interp_xyx[1, :, :, :] - @test interp_y β‰ˆ interp_xyx[2, :, :, :] - @test interp_x β‰ˆ interp_xyx[3, :, :, :] + @test interp_x β‰ˆ interp_xyx[:, :, :, 1] + @test interp_y β‰ˆ interp_xyx[:, :, :, 2] + @test interp_x β‰ˆ interp_xyx[:, :, :, 3] end # Remapping in-place one field @@ -417,21 +417,21 @@ end end # Two fields - dest = zeros(2, 21, 21, 21) + dest = zeros(21, 21, 21, 2) Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_y β‰ˆ dest[2, :, :, :] + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21, 21) + dest = zeros(21, 21, 21, 3) Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :, :] - @test interp_y β‰ˆ dest[2, :, :, :] - @test interp_x β‰ˆ dest[3, :, :, :] + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] end end @@ -455,17 +455,17 @@ end # Two fields interp_xy = Remapping.interpolate(horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_xy[1, :, :] β‰ˆ interp_x - @test interp_xy[2, :, :] β‰ˆ interp_y + @test interp_xy[:, :, 1] β‰ˆ interp_x + @test interp_xy[:, :, 2] β‰ˆ interp_y end # Three fields interp_xyx = Remapping.interpolate(horiz_remapper, [coords.x, coords.y, coords.x]) if ClimaComms.iamroot(context) - @test interp_xyx[1, :, :] β‰ˆ interp_x - @test interp_xyx[2, :, :] β‰ˆ interp_y - @test interp_xyx[3, :, :] β‰ˆ interp_x + @test interp_xyx[:, :, 1] β‰ˆ interp_x + @test interp_xyx[:, :, 2] β‰ˆ interp_y + @test interp_xyx[:, :, 3] β‰ˆ interp_x end # Remapping in-place one field @@ -481,25 +481,25 @@ end end # Two fields - dest = zeros(2, 21, 21) + dest = zeros(21, 21, 2) Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :] - @test interp_y β‰ˆ dest[2, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21) + dest = zeros(21, 21, 3) Remapping.interpolate!( dest, horiz_remapper, [coords.x, coords.y, coords.x], ) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[1, :, :] - @test interp_y β‰ˆ dest[2, :, :] - @test interp_x β‰ˆ dest[3, :, :] + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end end @@ -567,8 +567,8 @@ end interp_long_lat = Remapping.interpolate(remapper, [sind.(coords.long), sind.(coords.lat)]) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ interp_long_lat[1, :, :, :] - @test interp_sin_lat β‰ˆ interp_long_lat[2, :, :, :] + @test interp_sin_long β‰ˆ interp_long_lat[:, :, :, 1] + @test interp_sin_lat β‰ˆ interp_long_lat[:, :, :, 2] end # Remapping three fields (more than the buffer length) interp_long_lat_long = Remapping.interpolate( @@ -576,9 +576,9 @@ end [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], ) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ interp_long_lat_long[1, :, :, :] - @test interp_sin_lat β‰ˆ interp_long_lat_long[2, :, :, :] - @test interp_sin_long β‰ˆ interp_long_lat_long[3, :, :, :] + @test interp_sin_long β‰ˆ interp_long_lat_long[:, :, :, 1] + @test interp_sin_lat β‰ˆ interp_long_lat_long[:, :, :, 2] + @test interp_sin_long β‰ˆ interp_long_lat_long[:, :, :, 3] end # Remapping in-place one field @@ -595,29 +595,29 @@ end end # Two fields - dest = zeros(2, 21, 21, 21) + dest = zeros(21, 21, 21, 2) Remapping.interpolate!( dest, remapper, [sind.(coords.long), sind.(coords.lat)], ) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[1, :, :, :] - @test interp_sin_lat β‰ˆ dest[2, :, :, :] + @test interp_sin_long β‰ˆ dest[:, :, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21, 21) + dest = zeros(21, 21, 21, 3) Remapping.interpolate!( dest, remapper, [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], ) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[1, :, :, :] - @test interp_sin_lat β‰ˆ dest[2, :, :, :] - @test interp_sin_long β‰ˆ dest[3, :, :, :] + @test interp_sin_long β‰ˆ dest[:, :, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, :, 3] end end @@ -644,8 +644,8 @@ end [sind.(coords.long), sind.(coords.lat)], ) if ClimaComms.iamroot(context) - @test interp_sin_long_lat[1, :, :] β‰ˆ interp_sin_long - @test interp_sin_long_lat[2, :, :] β‰ˆ interp_sin_lat + @test interp_sin_long_lat[:, :, 1] β‰ˆ interp_sin_long + @test interp_sin_long_lat[:, :, 2] β‰ˆ interp_sin_lat end # Three fields @@ -654,9 +654,9 @@ end [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], ) if ClimaComms.iamroot(context) - @test interp_sin_long_lat_long[1, :, :] β‰ˆ interp_sin_long - @test interp_sin_long_lat_long[2, :, :] β‰ˆ interp_sin_lat - @test interp_sin_long_lat_long[3, :, :] β‰ˆ interp_sin_long + @test interp_sin_long_lat_long[:, :, 1] β‰ˆ interp_sin_long + @test interp_sin_long_lat_long[:, :, 2] β‰ˆ interp_sin_lat + @test interp_sin_long_lat_long[:, :, 3] β‰ˆ interp_sin_long end # Remapping in-place one field @@ -672,29 +672,29 @@ end end # Two fields - dest = zeros(2, 21, 21) + dest = zeros(21, 21, 2) Remapping.interpolate!( dest, horiz_remapper, [sind.(coords.long), sind.(coords.lat)], ) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[1, :, :] - @test interp_sin_lat β‰ˆ dest[2, :, :] + @test interp_sin_long β‰ˆ dest[:, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, 2] end # Three fields (more than buffer length) if !broken - dest = zeros(3, 21, 21) + dest = zeros(21, 21, 3) Remapping.interpolate!( dest, horiz_remapper, [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], ) if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[1, :, :] - @test interp_sin_lat β‰ˆ dest[2, :, :] - @test interp_sin_long β‰ˆ dest[3, :, :] + @test interp_sin_long β‰ˆ dest[:, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, 3] end end end From a27735eb2820a02f435d51f5c5855d94b42bffb5 Mon Sep 17 00:00:00 2001 From: Gabriele Bozzola Date: Thu, 18 Apr 2024 13:55:04 -0700 Subject: [PATCH 3/3] Remove all restrictions --- NEWS.md | 7 +- src/Remapping/distributed_remapping.jl | 41 ++-- test/Remapping/distributed_remapping.jl | 274 ++++++++++-------------- 3 files changed, 135 insertions(+), 187 deletions(-) diff --git a/NEWS.md b/NEWS.md index d3e4413d00..98007489e1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,8 +4,11 @@ ClimaCore.jl Release Notes main ------- -- ![][badge-πŸ€–precisionΞ”] `Remapper`s can now process multiple `Field`s at the same time if created with some `buffer_lenght > 1`. - PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)) Machine-precision differences are expected. +- ![][badge-πŸ€–precisionΞ”] ![][badge-πŸ’₯breaking] `Remapper`s can now process + multiple `Field`s at the same time if created with some `buffer_lenght > 1`. + PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)) + Machine-precision differences are expected. This change is breaking because + remappers now return the same array type as the input field. v0.13.4 ------- diff --git a/src/Remapping/distributed_remapping.jl b/src/Remapping/distributed_remapping.jl index ce05bdf6be..5e0c9d5cc6 100644 --- a/src/Remapping/distributed_remapping.jl +++ b/src/Remapping/distributed_remapping.jl @@ -795,30 +795,21 @@ function _collect_and_return_interpolated_values!( remapper::Remapper, num_fields::Int, ) - output_array = ClimaComms.reduce( + return ClimaComms.reduce( remapper.comms_ctx, remapper._interpolated_values[remapper.colons..., 1:num_fields], +, ) - - maybe_copy_to_cpu = - ClimaComms.device(remapper.comms_ctx) isa ClimaComms.CUDADevice ? - Array : identity - - return ClimaComms.iamroot(remapper.comms_ctx) ? - maybe_copy_to_cpu(output_array) : nothing end function _collect_interpolated_values!( dest, remapper::Remapper, index_field_begin::Int, - index_field_end::Int, + index_field_end::Int; + only_one_field, ) - num_fields = 1 + index_field_end - index_field_begin - only_one_field = num_fields == 1 - if only_one_field ClimaComms.reduce!( remapper.comms_ctx, @@ -829,18 +820,12 @@ function _collect_interpolated_values!( return nothing end - # CUDA.jl does not support views very well at the moment. We can only work with - # num_fields = buffer_length - num_fields == remapper.buffer_length || - error("Operation not currently supported") + num_fields = 1 + index_field_end - index_field_begin - # MPI.reduce! seems to behave nicely with respect to CPU/GPU. In particular, - # if the destination is on the CPU, but the source is on the GPU, the values - # are automatically moved. ClimaComms.reduce!( remapper.comms_ctx, - remapper._interpolated_values, - dest, + view(remapper._interpolated_values, remapper.colons..., 1:num_fields), + view(dest, remapper.colons..., index_field_begin:index_field_end), +, ) @@ -882,6 +867,9 @@ to be defined on the root process and to be `nothing` for the other processes. Note: `interpolate` allocates new arrays and has some internal type-instability, `interpolate!` is non-allocating and type-stable. +When using `interpolate!`, the `dest`ination has to be the same array type as the +device in use (e.g., `CuArray` for CUDA runs). + Example ======== @@ -975,6 +963,14 @@ function interpolate!( dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error( "Destination array is not compatible with remapper (size mismatch)", ) + + expected_array_type = + ClimaComms.array_type(ClimaComms.device(remapper.comms_ctx)) + + found_type = nameof(typeof(dest)) + + dest isa expected_array_type || + error("dest is a $found_type, expected $expected_array_type") end index_field_begin, index_field_end = 1, min(length(fields), remapper.buffer_length) @@ -999,7 +995,8 @@ function interpolate!( dest, remapper, index_field_begin, - index_field_end, + index_field_end; + only_one_field, ) index_field_end != length(fields) || break diff --git a/test/Remapping/distributed_remapping.jl b/test/Remapping/distributed_remapping.jl index d30ddd944e..6ae9dd54f9 100644 --- a/test/Remapping/distributed_remapping.jl +++ b/test/Remapping/distributed_remapping.jl @@ -17,6 +17,7 @@ using ClimaComms const context = ClimaComms.context() const pid, nprocs = ClimaComms.init(context) const device = ClimaComms.device() +ArrayType = ClimaComms.array_type(device) # log output only from root process logger_stream = ClimaComms.iamroot(context) ? stderr : devnull @@ -34,7 +35,7 @@ end end on_gpu = device isa ClimaComms.CUDADevice -broken = true +broken = false if !on_gpu @testset "2D extruded" begin @@ -83,16 +84,16 @@ if !on_gpu interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, z in zpts] end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in xpts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, 2:(end - 1)] β‰ˆ expected_z[:, 2:(end - 1)] - @test interp_z[:, 1] β‰ˆ + @test Array(interp_z[:, 2:(end - 1)]) β‰ˆ expected_z[:, 2:(end - 1)] + @test Array(interp_z[:, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts] - @test interp_z[:, end] β‰ˆ + @test Array(interp_z[:, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts] end @@ -113,16 +114,14 @@ if !on_gpu end # Remapping in-place one field - if !broken - dest = zeros(21, 21) - Remapping.interpolate!(dest, remapper, coords.x) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest - end + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest end # Two fields - dest = zeros(21, 21, 2) + dest = ArrayType(zeros(21, 21, 2)) Remapping.interpolate!(dest, remapper, [coords.x, coords.x]) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest[:, :, 1] @@ -130,18 +129,12 @@ if !on_gpu end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 3) - Remapping.interpolate!( - dest, - remapper, - [coords.x, coords.x, coords.x], - ) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[:, :, 1] - @test interp_x β‰ˆ dest[:, :, 2] - @test interp_x β‰ˆ dest[:, :, 3] - end + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.x, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_x β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end end @@ -191,21 +184,21 @@ end interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts, z in zpts] end interp_y = Remapping.interpolate(remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts, z in zpts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts, z in zpts] end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in xpts, y in ypts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts, y in ypts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end @@ -224,20 +217,14 @@ end end # Remapping in-place one field - # - # We have to change remapper for GPU to make sure it works for when have have only one - # field - remapper_1field = - on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : - remapper - dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper_1field, coords.x) + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest end # Two fields - dest = zeros(21, 21, 21, 2) + dest = ArrayType(zeros(21, 21, 21, 2)) Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest[:, :, :, 1] @@ -245,16 +232,15 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 21, 3) - Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[:, :, :, 1] - @test interp_y β‰ˆ dest[:, :, :, 2] - @test interp_x β‰ˆ dest[:, :, :, 3] - end + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) @@ -264,12 +250,12 @@ end interp_x = Remapping.interpolate(horiz_remapper, coords.x) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts] end interp_y = Remapping.interpolate(horiz_remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts] end # Two fields @@ -292,16 +278,15 @@ end # # We have to change remapper for GPU to make sure it works for when have have only one # field - if !broken - dest = zeros(21, 21) - Remapping.interpolate!(dest, remapper_1field, coords.x) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest - end + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest end + # Two fields - dest = zeros(21, 21, 2) + dest = ArrayType(zeros(21, 21, 2)) Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest[:, :, 1] @@ -309,19 +294,14 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 3) - Remapping.interpolate!( - dest, - horiz_remapper, - [coords.x, coords.y, coords.x], - ) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[:, :, 1] - @test interp_y β‰ˆ dest[:, :, 2] - @test interp_x β‰ˆ dest[:, :, 3] - end + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end + end @@ -371,21 +351,21 @@ end interp_x = Remapping.interpolate(remapper, coords.x) if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts, z in zpts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts, z in zpts] end interp_y = Remapping.interpolate(remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts, z in zpts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts, z in zpts] end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in xpts, y in ypts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in xpts, y in ypts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in xpts, y in ypts] end @@ -404,20 +384,14 @@ end end # Remapping in-place one field - # - # We have to change remapper for GPU to make sure it works for when have have only one - # field - remapper_1field = - on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : - remapper - dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper_1field, coords.x) + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, coords.x) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest end # Two fields - dest = zeros(21, 21, 21, 2) + dest = ArrayType(zeros(21, 21, 21, 2)) Remapping.interpolate!(dest, remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest[:, :, :, 1] @@ -425,16 +399,15 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 21, 3) - Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[:, :, :, 1] - @test interp_y β‰ˆ dest[:, :, :, 2] - @test interp_x β‰ˆ dest[:, :, :, 3] - end + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!(dest, remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, :, 1] + @test interp_y β‰ˆ dest[:, :, :, 2] + @test interp_x β‰ˆ dest[:, :, :, 3] end + # Horizontal space horiz_space = Spaces.horizontal_space(hv_center_space) horiz_remapper = Remapping.Remapper(horiz_space, hcoords, buffer_length = 2) @@ -444,12 +417,12 @@ end interp_x = Remapping.interpolate(horiz_remapper, coords.x) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_x β‰ˆ [x for x in xpts, y in ypts] + @test Array(interp_x) β‰ˆ [x for x in xpts, y in ypts] end interp_y = Remapping.interpolate(horiz_remapper, coords.y) if ClimaComms.iamroot(context) - @test interp_y β‰ˆ [y for x in xpts, y in ypts] + @test Array(interp_y) β‰ˆ [y for x in xpts, y in ypts] end # Two fields @@ -469,19 +442,15 @@ end end # Remapping in-place one field - # - # We have to change remapper for GPU to make sure it works for when have have only one - # field - if !broken - dest = zeros(21, 21) - Remapping.interpolate!(dest, remapper_1field, coords.x) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest - end + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, coords.x) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest end + # Two fields - dest = zeros(21, 21, 2) + dest = ArrayType(zeros(21, 21, 2)) Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y]) if ClimaComms.iamroot(context) @test interp_x β‰ˆ dest[:, :, 1] @@ -489,18 +458,12 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 3) - Remapping.interpolate!( - dest, - horiz_remapper, - [coords.x, coords.y, coords.x], - ) - if ClimaComms.iamroot(context) - @test interp_x β‰ˆ dest[:, :, 1] - @test interp_y β‰ˆ dest[:, :, 2] - @test interp_x β‰ˆ dest[:, :, 3] - end + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!(dest, horiz_remapper, [coords.x, coords.y, coords.x]) + if ClimaComms.iamroot(context) + @test interp_x β‰ˆ dest[:, :, 1] + @test interp_y β‰ˆ dest[:, :, 2] + @test interp_x β‰ˆ dest[:, :, 3] end end @@ -543,23 +506,23 @@ end interp_sin_long = Remapping.interpolate(remapper, sind.(coords.long)) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ + @test Array(interp_sin_long) β‰ˆ [sind(x) for x in longpts, y in latpts, z in zpts] rtol = 0.01 end interp_sin_lat = Remapping.interpolate(remapper, sind.(coords.lat)) if ClimaComms.iamroot(context) - @test interp_sin_lat β‰ˆ + @test Array(interp_sin_lat) β‰ˆ [sind(y) for x in longpts, y in latpts, z in zpts] rtol = 0.01 end interp_z = Remapping.interpolate(remapper, coords.z) expected_z = [z for x in longpts, y in latpts, z in zpts] if ClimaComms.iamroot(context) - @test interp_z[:, :, 2:(end - 1)] β‰ˆ expected_z[:, :, 2:(end - 1)] - @test interp_z[:, :, 1] β‰ˆ + @test Array(interp_z[:, :, 2:(end - 1)]) β‰ˆ expected_z[:, :, 2:(end - 1)] + @test Array(interp_z[:, :, 1]) β‰ˆ [1000.0 * (0 / 30 + 1 / 30) / 2 for x in longpts, y in latpts] - @test interp_z[:, :, end] β‰ˆ + @test Array(interp_z[:, :, end]) β‰ˆ [1000.0 * (29 / 30 + 30 / 30) / 2 for x in longpts, y in latpts] end @@ -582,20 +545,14 @@ end end # Remapping in-place one field - # - # We have to change remapper for GPU to make sure it works for when have have only one - # field - remapper_1field = - on_gpu ? Remapping.Remapper(hv_center_space, hcoords, zcoords) : - remapper - dest = zeros(21, 21, 21) - Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) + dest = ArrayType(zeros(21, 21, 21)) + Remapping.interpolate!(dest, remapper, sind.(coords.long)) if ClimaComms.iamroot(context) @test interp_sin_long β‰ˆ dest end # Two fields - dest = zeros(21, 21, 21, 2) + dest = ArrayType(zeros(21, 21, 21, 2)) Remapping.interpolate!( dest, remapper, @@ -607,18 +564,16 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 21, 3) - Remapping.interpolate!( - dest, - remapper, - [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], - ) - if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[:, :, :, 1] - @test interp_sin_lat β‰ˆ dest[:, :, :, 2] - @test interp_sin_long β‰ˆ dest[:, :, :, 3] - end + dest = ArrayType(zeros(21, 21, 21, 3)) + Remapping.interpolate!( + dest, + remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[:, :, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, :, 3] end # Horizontal space @@ -630,12 +585,12 @@ end interp_sin_long = Remapping.interpolate(horiz_remapper, sind.(coords.long)) # Only root has the final result if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ [sind(x) for x in longpts, y in latpts] rtol = 0.01 + @test Array(interp_sin_long) β‰ˆ [sind(x) for x in longpts, y in latpts] rtol = 0.01 end interp_sin_lat = Remapping.interpolate(horiz_remapper, sind.(coords.lat)) if ClimaComms.iamroot(context) - @test interp_sin_lat β‰ˆ [sind(y) for x in longpts, y in latpts] rtol = 0.01 + @test Array(interp_sin_lat) β‰ˆ [sind(y) for x in longpts, y in latpts] rtol = 0.01 end # Two fields @@ -660,19 +615,14 @@ end end # Remapping in-place one field - # - # We have to change remapper for GPU to make sure it works for when have have only one - # field - if !broken - dest = zeros(21, 21) - Remapping.interpolate!(dest, remapper_1field, sind.(coords.long)) - if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest - end + dest = ArrayType(zeros(21, 21)) + Remapping.interpolate!(dest, horiz_remapper, sind.(coords.long)) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest end # Two fields - dest = zeros(21, 21, 2) + dest = ArrayType(zeros(21, 21, 2)) Remapping.interpolate!( dest, horiz_remapper, @@ -684,17 +634,15 @@ end end # Three fields (more than buffer length) - if !broken - dest = zeros(21, 21, 3) - Remapping.interpolate!( - dest, - horiz_remapper, - [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], - ) - if ClimaComms.iamroot(context) - @test interp_sin_long β‰ˆ dest[:, :, 1] - @test interp_sin_lat β‰ˆ dest[:, :, 2] - @test interp_sin_long β‰ˆ dest[:, :, 3] - end + dest = ArrayType(zeros(21, 21, 3)) + Remapping.interpolate!( + dest, + horiz_remapper, + [sind.(coords.long), sind.(coords.lat), sind.(coords.long)], + ) + if ClimaComms.iamroot(context) + @test interp_sin_long β‰ˆ dest[:, :, 1] + @test interp_sin_lat β‰ˆ dest[:, :, 2] + @test interp_sin_long β‰ˆ dest[:, :, 3] end end