Skip to content

Commit

Permalink
interpolate_array: add support for GPUs
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed Oct 30, 2023
1 parent 28bbd23 commit 3696fb6
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 138 deletions.
13 changes: 13 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,25 @@ steps:
command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_CONTEXT: "MPI"
CLIMACOMMS_DEVICE: "CPU"
agents:
slurm_ntasks: 2

- label: "Unit: distributed remapping (1 process)"
key: distributed_remapping_1proc
command: "julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_DEVICE: "CPU"

- label: "Unit: distributed remapping with CUDA"
key: distributed_remapping
command: "srun julia --color=yes --check-bounds=yes --project=test test/Remapping/distributed_remapping.jl"
env:
CLIMACOMMS_CONTEXT: "MPI"
CLIMACOMMS_DEVICE: "CUDA"
agents:
slurm_ntasks: 2
slurm_gpus: 1

- label: "Unit: distributed gather"
key: unit_distributed_gather4
Expand Down
1 change: 1 addition & 0 deletions src/Remapping/Remapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import ..DataLayouts,
..Hypsography

using ..RecursiveApply
using CUDA

include("interpolate_array.jl")
include("distributed_remapping.jl")
Expand Down
211 changes: 161 additions & 50 deletions src/Remapping/interpolate_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ function interpolate_slab(
)
space = axes(field)
x = zero(eltype(field))
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)
Nq = length(I1)

for i in 1:Nq
ij = CartesianIndex((i,))
Expand All @@ -22,22 +21,74 @@ function interpolate_slab(
)
space = axes(field)
x = zero(eltype(field))
QS = Spaces.quadrature_style(space)
Nq = Spaces.Quadratures.degrees_of_freedom(QS)
Nq1, Nq2 = length(I1), length(I2)

for j in 1:Nq, i in 1:Nq
for j in 1:Nq2, i in 1:Nq1
ij = CartesianIndex((i, j))
x += I1[i] * I2[j] * Operators.get_node(space, field, ij, slabidx)
end
return x
end

"""
    vertical_indices_ref_coordinate(space, zcoord)

Return a tuple `(v_lo, v_hi, ξ3)`: the vertical indices of the elements below and above
`zcoord`, and the reference coordinate `ξ3` to use when interpolating vertically between
them.

Methods are defined for `Spaces.FaceExtrudedFiniteDifferenceSpace` and
`Spaces.CenterExtrudedFiniteDifferenceSpace`.
"""
function vertical_indices_ref_coordinate end

function vertical_indices_ref_coordinate(
    space::Spaces.FaceExtrudedFiniteDifferenceSpace,
    zcoord,
)
    # Face-space method: locate the vertical element containing `zcoord` and return the
    # two half-offset indices around it, plus the reference coordinate reported by the
    # mesh for that element.
    mesh = Spaces.vertical_topology(space).mesh
    elem = Meshes.containing_element(mesh, zcoord)
    # `reference_coordinates` returns a collection; only its first entry is used.
    ξ3, = Meshes.reference_coordinates(mesh, elem, zcoord)
    return elem - half, elem + half, ξ3
end

function vertical_indices_ref_coordinate(
    space::Spaces.CenterExtrudedFiniteDifferenceSpace,
    zcoord,
)
    # Center-space method: pick the containing element and one neighbour, chosen by the
    # sign of the reference coordinate, and shift the reference coordinate accordingly.
    topology = Spaces.vertical_topology(space)
    mesh = topology.mesh
    nlev = Spaces.nlevels(space)
    periodic = Topologies.isperiodic(topology)

    elem = Meshes.containing_element(mesh, zcoord)
    ξ3, = Meshes.reference_coordinates(mesh, elem, zcoord)

    if ξ3 < 0
        # Pair the containing element with the one below it. Periodic topologies wrap
        # around; otherwise the index is clamped to the first level.
        below = periodic ? mod1(elem - 1, nlev) : max(elem - 1, 1)
        return below, elem, ξ3 + 1
    end

    # Pair the containing element with the one above it. Periodic topologies wrap
    # around; otherwise the index is clamped to the last level.
    above = periodic ? mod1(elem + 1, nlev) : min(elem + 1, nlev)
    return elem, above, ξ3 - 1
end

"""
interpolate_slab_level!(
output_array,
field::Fields.Field,
h::Integer,
Is::Tuple,
zpts;
fill_value = eltype(field)(NaN)
)
Vertically interpolate the given `field` at each of the points in `zpts` and write the results into `output_array`.
Expand All @@ -49,47 +100,111 @@ element in a column, no interpolation is performed and the value at the cell cen
returned. Effectively, this means that the interpolation is first-order accurate across the
column, but zeroth-order accurate close to the boundaries.
Return `fill_value` when the vertical coordinate is negative.
"""
function interpolate_slab_level!(
    output_array,
    field::Fields.Field,
    h::Integer,
    Is::Tuple,
    zpts;
    fill_value = eltype(field)(NaN),
)
    # Device-agnostic entry point: look up the device associated with `field` and forward
    # to the method that takes the device as an extra positional argument, so that the
    # CPU or CUDA implementation is selected by dispatch.
    device = ClimaComms.device(field)
    return interpolate_slab_level!(
        output_array,
        field,
        h,
        Is,
        zpts,
        device;
        fill_value,
    )
end

velem = Meshes.containing_element(vert_mesh, zcoord)
ξ3, = Meshes.reference_coordinates(vert_mesh, velem, zcoord)
if space isa Spaces.FaceExtrudedFiniteDifferenceSpace
v_lo = velem - half
v_hi = velem + half
elseif space isa Spaces.CenterExtrudedFiniteDifferenceSpace
if ξ3 < 0
if Topologies.isperiodic(Spaces.vertical_topology(space))
v_lo = mod1(velem - 1, Nz)
else
v_lo = max(velem - 1, 1)
end
v_hi = velem
ξ3 = ξ3 + 1
else
v_lo = velem
if Topologies.isperiodic(Spaces.vertical_topology(space))
v_hi = mod1(velem + 1, Nz)
else
v_hi = min(velem + 1, Nz)
end
ξ3 = ξ3 - 1
end
function interpolate_slab_level!(
    output_array,
    field::Fields.Field,
    h::Integer,
    Is::Tuple,
    zpts,
    device::ClimaComms.AbstractCPUDevice;
    fill_value = Spaces.undertype(axes(field))(NaN),
)
    # CPU implementation: compute one value per entry of `zpts` and store the results in
    # `output_array`. `h` is the horizontal element index and `Is` the horizontal
    # interpolation weights passed through to `interpolate_slab`.
    space = axes(field)
    output_array .= map(zpts) do zcoord
        # Points with negative z are not interpolated; they get `fill_value`.
        zcoord.z < 0 && return fill_value

        v_lo, v_hi, ξ3 = vertical_indices_ref_coordinate(space, zcoord)

        # Horizontal interpolation on the two bracketing levels, then a linear blend of
        # the two values weighted by the reference coordinate.
        f_lo = interpolate_slab(field, Fields.SlabIndex(v_lo, h), Is)
        f_hi = interpolate_slab(field, Fields.SlabIndex(v_hi, h), Is)
        return ((1 - ξ3) * f_lo + (1 + ξ3) * f_hi) / 2
    end
    return output_array
end

function interpolate_slab_level!(
    output_array,
    field::Fields.Field,
    h::Integer,
    Is::Tuple,
    zpts,
    device::ClimaComms.CUDADevice;
    fill_value = Spaces.undertype(axes(field))(NaN),
)
    # CUDA implementation.
    #
    # We have to deal with topography and NaNs. For that, we select the points that have
    # z >= 0 and interpolate only those on the GPU. The output array is filled with
    # `fill_value`, and only the entries that were actually interpolated are overwritten.
    # This is a simple way to avoid branching on the sign of z inside the kernel.
    space = axes(field)
    FT = Spaces.undertype(space)

    output_array .= fill_value

    is_nonnegative_z = [z.z >= 0 for z in zpts]

    # Nothing to interpolate: skip the device transfers and the kernel launch, which
    # would otherwise be issued for zero work items.
    any(is_nonnegative_z) || return output_array

    # The vertical indices and reference coordinates are computed on the host and then
    # copied to the device.
    vertical_indices_ref_coordinates = CuArray([
        vertical_indices_ref_coordinate(space, zcoord) for
        zcoord in zpts[is_nonnegative_z]
    ])

    # Size both the device buffer and the launch by the number of points that are
    # actually interpolated, not by the total number of requested points.
    nitems = length(vertical_indices_ref_coordinates)
    output_cuarray = CuArray(zeros(FT, nitems))

    nthreads, nblocks = Spaces._configure_threadblock(nitems)
    @cuda threads = (nthreads) blocks = (nblocks) interpolate_slab_level_kernel!(
        output_cuarray,
        field,
        vertical_indices_ref_coordinates,
        h,
        Is,
    )

    # Copy the device results back and scatter them into the selected slots.
    output_array[is_nonnegative_z] .= Array(output_cuarray)
    return output_array
end

function interpolate_slab_level_kernel!(
    output_array,
    field,
    vidx_ref_coordinates,
    h,
    Is,
)
    # Kernel body: each thread handles at most one entry of `output_array`, identified by
    # its linear index across blocks.
    index = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads whose index falls past the end of the output have nothing to do.
    index > length(output_array) && return nothing

    v_lo, v_hi, ξ3 = vidx_ref_coordinates[index]

    # Horizontal interpolation on the two levels, then a linear blend weighted by ξ3.
    lower = interpolate_slab(field, Fields.SlabIndex(v_lo, h), Is)
    upper = interpolate_slab(field, Fields.SlabIndex(v_hi, h), Is)
    output_array[index] = ((1 - ξ3) * lower + (1 + ξ3) * upper) / 2

    return nothing
end


"""
interpolate_array(field, xpts, ypts)
interpolate_array(field, xpts, ypts, zpts)
Expand Down Expand Up @@ -136,9 +251,7 @@ function interpolate_array(
weights = interpolation_weights(horz_mesh, hcoord, quad_points)
h = helem

for (iz, zcoord) in enumerate(zpts)
array[ix, iz] = interpolate_slab_level(field, h, weights, zcoord)
end
interpolate_slab_level!(view(array, ix, :), field, h, weights, zpts)
end
return array
end
Expand All @@ -159,6 +272,7 @@ function interpolate_array(
array = zeros(T, length(xpts), length(ypts), length(zpts))

FT = Spaces.undertype(space)

for (iy, ycoord) in enumerate(ypts), (ix, xcoord) in enumerate(xpts)
hcoord = Geometry.product_coordinates(xcoord, ycoord)
helem = Meshes.containing_element(horz_mesh, hcoord)
Expand All @@ -168,10 +282,7 @@ function interpolate_array(
gidx = horz_topology.orderindex[helem]
h = gidx

for (iz, zcoord) in enumerate(zpts)
array[ix, iy, iz] =
interpolate_slab_level(field, h, weights, zcoord)
end
interpolate_slab_level!(view(array, ix, iy, :), field, h, weights, zpts)
end
return array
end
Expand Down Expand Up @@ -233,7 +344,6 @@ function interpolate_column(
physical_z = false,
)
space = axes(field)
FT = Spaces.undertype(space)

# When we don't have hypsography, there is no notion of "interpolating hypsography". In
# this case, the reference vertical points coincide with the physical ones. Setting
Expand Down Expand Up @@ -263,8 +373,9 @@ function interpolate_column(
zpts_ref = zpts
end

return [
z.z >= 0 ? interpolate_slab_level(field, gidx, weights, z) : FT(NaN) for
z in zpts_ref
]
output_array = zeros(Spaces.undertype(space), length(zpts))

interpolate_slab_level!(output_array, field, gidx, weights, zpts_ref)

return output_array
end
Loading

0 comments on commit 3696fb6

Please sign in to comment.