Move field index to last
Sbozzolo committed Apr 18, 2024
1 parent 38ffafd commit a7eb55e
Showing 2 changed files with 130 additions and 123 deletions.
85 changes: 46 additions & 39 deletions src/Remapping/distributed_remapping.jl
@@ -181,13 +181,13 @@ struct Remapper{

# Scratch space where we save the process-local field value. We keep overwriting this to
# avoid extra allocations. Ideally, we wouldn't need this and we would use views for
-    # everything. This has dimensions (buffer_length, Nq) or (buffer_length, Nq, Nq)
+    # everything. This has dimensions (Nq, ) or (Nq, Nq, )
    # depending on whether the horizontal space is 1D or 2D.
_field_values::T9

# Storage area where the interpolated values are saved. This is meaningful only for the
    # root process and gets filled by an interpolate call. This has dimensions
-    # (buffer_length, H, V), where H is the size of target_hcoords and V of target_zcoords.
+    # (H, V, buffer_length), where H is the size of target_hcoords and V of target_zcoords.
# In other words, this is the expected output array.
_interpolated_values::T10

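Why move the field index last: Julia arrays are column-major, so making the batched-field index the trailing dimension keeps each field's output contiguous in memory and makes per-field slices cheap views. A minimal sketch of the layout difference, with made-up shapes (not taken from the remapper):

```julia
# Hypothetical sizes: H target horizontal points, V vertical points,
# buffer_length batched fields.
H, V, buffer_length = 10, 4, 3
out = zeros(Float64, H, V, buffer_length)

# With the field index last, one field is a contiguous (H, V) block:
field2 = view(out, :, :, 2)
@assert size(field2) == (H, V)

# Under the old (buffer_length, H, V) ordering, out[2, :, :] would select a
# strided, non-contiguous set of elements instead.
```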
@@ -308,9 +308,9 @@ function Remapper(
vert_interpolation_weights = nothing
vert_bounding_indices = nothing
local_interpolated_values =
-        ArrayType(zeros(FT, (buffer_length, size(local_horiz_indices)...)))
+        ArrayType(zeros(FT, (size(local_horiz_indices)..., buffer_length)))
interpolated_values = ArrayType(
-        zeros(FT, (buffer_length, size(local_target_hcoords_bitmask)...)),
+        zeros(FT, (size(local_target_hcoords_bitmask)..., buffer_length)),
)
num_dims = num_hdims
else
@@ -325,19 +325,19 @@
zeros(
FT,
(
-                    buffer_length,
size(local_horiz_indices)...,
length(target_zcoords),
+                    buffer_length,
),
),
)
interpolated_values = ArrayType(
zeros(
FT,
(
-                    buffer_length,
size(local_target_hcoords_bitmask)...,
length(target_zcoords),
+                    buffer_length,
),
),
)
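
A sketch of the shapes these two allocations produce in the branch with vertical coordinates (concrete sizes are hypothetical and the GPU `ArrayType` wrapper is omitted):

```julia
FT = Float32
num_local_horiz = 128   # stand-in for size(local_horiz_indices)...
num_z = 16              # stand-in for length(target_zcoords)
buffer_length = 4

local_interpolated_values = zeros(FT, (num_local_horiz, num_z, buffer_length))
interpolated_values = zeros(FT, (num_local_horiz, num_z, buffer_length))

# The batch (field) axis is now last in both buffers:
@assert size(interpolated_values)[end] == buffer_length
```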
@@ -433,7 +433,7 @@ function set_interpolated_values_cpu_kernel!(
I2[out_index, j] *
scratch_field_values[i, j]
end
-            out[field_index, out_index, vindex] = tmp
+            out[out_index, vindex, field_index] = tmp
end
end
end
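
The CPU kernel above computes a tensor-product interpolation: the value at one target point is a double sum of nodal values weighted by one row of each 1D interpolation matrix. A self-contained sketch with made-up weights (names mirror the kernel, values are random):

```julia
Nq = 4
f = rand(Nq, Nq)   # nodal values of one field in one element
w1 = rand(Nq)      # stand-in for I1[out_index, :]
w2 = rand(Nq)      # stand-in for I2[out_index, :]

# Double sum over the Nq×Nq nodes, as in the i/j loop of the kernel:
interp2d(w1, w2, f) =
    sum(w1[i] * w2[j] * f[i, j] for j in axes(f, 2), i in axes(f, 1))

tmp = interp2d(w1, w2, f)   # the value stored in out[out_index, vindex, field_index]
```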
@@ -470,9 +470,9 @@ function set_interpolated_values_kernel!(
A, B = vert_interpolation_weights[j]
for k in findex:totalThreadsZ:num_fields
            if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
-                out[k, i, j] = 0
+                out[i, j, k] = 0
for t in 1:Nq, s in 1:Nq
-                    out[k, i, j] +=
+                    out[i, j, k] +=
I1[i, t] *
I2[i, s] *
(
@@ -509,7 +509,7 @@ function set_interpolated_values_kernel!(
totalThreadsY = gridDim().y * blockDim().y
totalThreadsZ = gridDim().z * blockDim().z

-    _, Nq = size(I1)
+    _, Nq = size(I)

for i in hindex:totalThreadsX:num_horiz
h = local_horiz_indices[i]
@@ -518,11 +518,11 @@
A, B = vert_interpolation_weights[j]
for k in findex:totalThreadsZ:num_fields
            if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
-                out[k, i, j] = 0
+                out[i, j, k] = 0
                for t in 1:Nq
-                    out[k, i, j] +=
-                        I1[i, t] *
-                        I2[i, s] *
+                    out[i, j, k] +=
+                        I[i, t] *
(
A *
field_values[k][t, nothing, nothing, v_lo, h] +
@@ -577,7 +577,7 @@ function set_interpolated_values_cpu_kernel!(
for i in 1:Nq
tmp += I[out_index, i] * scratch_field_values[i]
end
-        out[field_index, out_index, vindex] = tmp
+        out[out_index, vindex, field_index] = tmp
end
end
end
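
In the 1D horizontal case, the reduction above collapses to a dot product between one row of the interpolation matrix and the scratch values; a minimal equivalent (all values invented):

```julia
using LinearAlgebra

Nq = 4
I_row = rand(Nq)     # stand-in for I[out_index, :]
scratch = rand(Nq)   # stand-in for scratch_field_values

tmp = dot(I_row, scratch)
@assert tmp ≈ sum(I_row[i] * scratch[i] for i in 1:Nq)
```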
@@ -652,17 +652,17 @@ function _set_interpolated_values!(
for (field_index, field) in enumerate(fields)
field_values = Fields.field_values(field)
for (out_index, h) in enumerate(local_horiz_indices)
-        out[field_index, out_index] = zero(FT)
+        out[out_index, field_index] = zero(FT)
if hdims == 2
for j in 1:Nq, i in 1:Nq
-                out[field_index, out_index] +=
+                out[out_index, field_index] +=
local_horiz_interpolation_weights[1][out_index, i] *
local_horiz_interpolation_weights[2][out_index, j] *
field_values[i, j, nothing, nothing, h]
end
elseif hdims == 1
for i in 1:Nq
-                out[field_index, out_index] +=
+                out[out_index, field_index] +=
local_horiz_interpolation_weights[1][out_index, i] *
field_values[i, nothing, nothing, nothing, h]
end
@@ -694,9 +694,9 @@ function set_interpolated_values_kernel!(
h = local_horiz_indices[i]
for k in findex:totalThreadsZ:num_fields
            if i ≤ num_horiz && k ≤ num_fields
-                out[k, i] = 0
+                out[i, k] = 0
for t in 1:Nq, s in 1:Nq
-                    out[k, i] +=
+                    out[i, k] +=
I1[i, t] *
I2[i, s] *
field_values[k][t, s, nothing, nothing, h]
@@ -729,9 +729,9 @@ function set_interpolated_values_kernel!(
h = local_horiz_indices[i]
for k in findex:totalThreadsZ:num_fields
            if i ≤ num_horiz && k ≤ num_fields
-            out[k, i] = 0
+            out[i, k] = 0
            for t in 1:Nq
-                out[k, i] +=
+                out[i, k] +=
                I[i, t] *
field_values[k][t, nothing, nothing, nothing, h]
end
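
All of these GPU kernels use grid-stride loops (`for i in hindex:totalThreadsX:num_horiz`), so a fixed launch grid can cover any problem size. A CPU mock of the same striding pattern (the thread ids here are plain integers standing in for the CUDA intrinsics):

```julia
function scale_strided!(out, data, thread_id, total_threads)
    # Each "thread" starts at its id and jumps by the total thread count,
    # exactly like hindex:totalThreadsX:num_horiz in the kernels above.
    for i in thread_id:total_threads:length(data)
        out[i] = 2 * data[i]
    end
    return out
end

data = rand(10)
out = similar(data)
for tid in 1:4           # pretend launch with 4 threads
    scale_strided!(out, data, tid, 4)
end
@assert out == 2 .* data
```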
Expand All @@ -753,16 +753,20 @@ around according to MPI-ownership and the expected output shape.
`interpolated_values`. We assume that it is always the first `num_fields` that have to be moved.
"""
function _apply_mpi_bitmask!(remapper::Remapper, num_fields::Int)
-    view(
-        remapper._interpolated_values,
-        1:num_fields,
-        remapper.local_target_hcoords_bitmask,
-        :,
-    ) .= view(
-        remapper._local_interpolated_values,
-        1:num_fields,
-        remapper.colons...,
-    )
+    if isnothing(remapper.target_zcoords)
+        view(
+            remapper._interpolated_values,
+            remapper.local_target_hcoords_bitmask,
+            1:num_fields,
+        ) .= view(remapper._local_interpolated_values, :, 1:num_fields)
+    else
+        view(
+            remapper._interpolated_values,
+            remapper.local_target_hcoords_bitmask,
+            :,
+            1:num_fields,
+        ) .= view(remapper._local_interpolated_values, :, :, 1:num_fields)
+    end
end
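
The masked-view assignment above scatters the process-local rows into the slots of the global output that this process owns. A simplified sketch with a 1D bitmask (the real mask covers the horizontal target dimensions; sizes are hypothetical):

```julia
H, V, num_fields = 5, 3, 2
global_out = zeros(H, V, num_fields)
mask = [true, false, true, false, true]      # this process owns 3 of 5 points
local_vals = rand(count(mask), V, num_fields)

view(global_out, mask, :, 1:num_fields) .= local_vals

# Rows not owned by this process stay zero for the later reduction:
@assert all(global_out[2, :, :] .== 0)
```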

"""
@@ -793,7 +797,7 @@ function _collect_and_return_interpolated_values!(
)
output_array = ClimaComms.reduce(
remapper.comms_ctx,
-        remapper._interpolated_values[1:num_fields, remapper.colons...],
+        remapper._interpolated_values[remapper.colons..., 1:num_fields],
+,
)

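Summing on the root works because each rank only fills the entries it owns and leaves everything else zero, so an elementwise `+` across ranks assembles the complete array. A two-rank mock of the reduction (no MPI, just the arithmetic):

```julia
rank0 = [1.0 0.0; 0.0 0.0]    # rank 0 owns the (1, 1) target point
rank1 = [0.0 0.0; 0.0 2.0]    # rank 1 owns the (2, 2) target point

# What ClimaComms.reduce(ctx, local_array, +) produces on the root:
root_result = rank0 .+ rank1
@assert root_result == [1.0 0.0; 0.0 2.0]
```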
@@ -818,7 +822,7 @@ function _collect_interpolated_values!(
if only_one_field
ClimaComms.reduce!(
remapper.comms_ctx,
-            remapper._interpolated_values[1, remapper.colons...],
+            remapper._interpolated_values[remapper.colons..., begin],
dest,
+,
)
@@ -922,7 +926,9 @@ function interpolate(remapper::Remapper, fields)
# buffer_length
index_ranges = batched_ranges(length(fields), remapper.buffer_length)

-    interpolated_values = mapreduce(vcat, index_ranges) do range
+    cat_fn = (l...) -> cat(l..., dims = length(remapper.colons) + 1)
+
+    interpolated_values = mapreduce(cat_fn, index_ranges) do range
num_fields = length(range)

# Reset interpolated_values. This is needed because we collect distributed results
@@ -938,13 +944,14 @@
# Finally, we have to send all the _interpolated_values to root and sum them up to
# obtain the final answer. Only the root will contain something useful. This also
# moves the data off the GPU
-        return _collect_and_return_interpolated_values!(remapper, num_fields)
+        ret = _collect_and_return_interpolated_values!(remapper, num_fields)
+        return ret
end

# Non-root processes
isnothing(interpolated_values) && return nothing

return only_one_field ? interpolated_values[begin, remapper.colons...] :
return only_one_field ? interpolated_values[remapper.colons..., begin] :
interpolated_values
end
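Fields are processed in batches of at most `buffer_length`, and each batch yields an array whose trailing axis holds that batch's fields; concatenating along `dims = length(colons) + 1` (the trailing field axis) rebuilds the full output. A sketch with hypothetical sizes:

```julia
H, V = 6, 2
batch1 = rand(H, V, 3)     # a full buffer_length = 3 batch
batch2 = rand(H, V, 1)     # the leftover field

# length(remapper.colons) would be 2 here (H and V), so cat along dims = 3:
all_fields = cat(batch1, batch2; dims = 3)
@assert size(all_fields) == (H, V, 4)
```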

@@ -963,9 +970,9 @@ function interpolate!(
if !isnothing(dest)
        # !isnothing(dest) means that this is the root process; in this case, the sizes
        # have to match (ignoring the buffer_length)
-        dest_size = only_one_field ? size(dest) : size(dest)[2:end]
+        dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]

-        dest_size == size(remapper._interpolated_values)[2:end] || error(
+        dest_size == size(remapper._interpolated_values)[1:(end - 1)] || error(
"Destination array is not compatible with remapper (size mismatch)",
)
end
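
A sketch of what this compatibility check accepts: a preallocated `dest` must match the remapper's output shape with the trailing buffer axis dropped (sizes invented for illustration):

```julia
interpolated = zeros(6, 2, 4)    # like _interpolated_values: (H, V, buffer_length)
dest = zeros(6, 2, 3)            # room for three fields
only_one_field = false

dest_size = only_one_field ? size(dest) : size(dest)[1:(end - 1)]
@assert dest_size == size(interpolated)[1:(end - 1)]   # (6, 2) == (6, 2)
```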