Skip to content

Commit

Permalink
Merge pull request #1843 from CliMA/ck/topo_tests
Browse files Browse the repository at this point in the history
Replace `_get_idx` w/ `CartesianIndices/LinearIndices`
  • Loading branch information
charleskawczynski authored Jun 24, 2024
2 parents 5e9d10d + 18247d6 commit 5f23521
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 97 deletions.
1 change: 1 addition & 0 deletions ext/ClimaCoreCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using CUDA: threadIdx, blockIdx, blockDim
import StaticArrays: SVector, SMatrix, SArray
import ClimaCore.DataLayouts: slab, column
import ClimaCore.Utilities: half
import ClimaCore.Utilities: cart_ind, linear_ind
import ClimaCore.RecursiveApply:
, , , radd, rmul, rsub, rdiv, rmap, rzero, rmin, rmax

Expand Down
2 changes: 1 addition & 1 deletion ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ function copyto_stencil_kernel!(
gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x
if gid Nv * Nq * Nq * Nh
(li, lw, rw, ri) = bds
(v, i, j, h) = Topologies._get_idx((Nv, Nq, Nq, Nh), gid)
(v, i, j, h) = cart_ind((Nv, Nq, Nq, Nh), gid).I
hidx = (i, j, h)
idx = v - 1 + li
window =
Expand Down
6 changes: 3 additions & 3 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function column_integral_definite_kernel!(
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
∫field_column = Spaces.column(∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_definite!(∫field_column, ᶜfield_column)
Expand All @@ -52,7 +52,7 @@ function column_integral_indefinite_kernel!(
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
ᶠ∫field_column = Spaces.column(ᶠ∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_indefinite!(ᶠ∫field_column, ᶜfield_column)
Expand Down Expand Up @@ -117,7 +117,7 @@ function column_mapreduce_kernel_extruded!(
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
reduced_field_column = Spaces.column(reduced_field, i, j, h)
field_columns = map(field -> Spaces.column(field, i, j, h), fields)
_column_mapreduce!(fn, op, reduced_field_column, field_columns...)
Expand Down
2 changes: 1 addition & 1 deletion ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function thomas_algorithm_kernel!(
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
if idx <= Ni * Nj * Nh
i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h))
end
return nothing
Expand Down
60 changes: 25 additions & 35 deletions ext/cuda/topologies_dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ function dss_load_perimeter_data_kernel!(
sized = (nlevels, Nq, Nq, nfidx, nelems) # size of data

if gidx prod(sizep)
(level, p, fidx, elem) = Topologies._get_idx(sizep, gidx)
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
(ip, jp) = perimeter[p]
data_idx = Topologies._get_idx(sized, (level, ip, jp, fidx, elem))
data_idx = linear_ind(sized, (level, ip, jp, fidx, elem))
pperimeter_data[level, p, fidx, elem] = pdata[data_idx]
end
return nothing
Expand Down Expand Up @@ -86,9 +86,9 @@ function dss_unload_perimeter_data_kernel!(
sized = (nlevels, Nq, Nq, nfidx, nelems) # size of data

if gidx prod(sizep)
(level, p, fidx, elem) = Topologies._get_idx(sizep, gidx)
(level, p, fidx, elem) = cart_ind(sizep, gidx).I
(ip, jp) = perimeter[p]
data_idx = Topologies._get_idx(sized, (level, ip, jp, fidx, elem))
data_idx = linear_ind(sized, (level, ip, jp, fidx, elem))
pdata[data_idx] = pperimeter_data[level, p, fidx, elem]
end
return nothing
Expand Down Expand Up @@ -139,7 +139,7 @@ function dss_local_kernel!(
(nlevels, nperimeter, nfidx, _) = size(pperimeter_data)
if gidx nlevels * nfidx * nlocalvertices # local vertices
sizev = (nlevels, nfidx, nlocalvertices)
(level, fidx, vertexid) = Topologies._get_idx(sizev, gidx)
(level, fidx, vertexid) = cart_ind(sizev, gidx).I
sum_data = FT(0)
st, en =
local_vertex_offset[vertexid], local_vertex_offset[vertexid + 1]
Expand All @@ -157,7 +157,7 @@ function dss_local_kernel!(
nfacedof = div(nperimeter - 4, 4)
sizef = (nlevels, nfidx, nlocalfaces)
(level, fidx, faceid) =
Topologies._get_idx(sizef, gidx - nlevels * nfidx * nlocalvertices)
cart_ind(sizef, gidx - nlevels * nfidx * nlocalvertices).I
(lidx1, face1, lidx2, face2, reversed) = interior_faces[faceid]
(first1, inc1, last1) =
Topologies.perimeter_face_indices_cuda(face1, nfacedof, false)
Expand Down Expand Up @@ -247,21 +247,18 @@ function dss_transform_kernel!(
sizet_wt = (Nq, Nq, 1, nelems)
sizet_metric = (nlevels, Nq, Nq, nmetric, nelems)

(level, p, localelemno) = Topologies._get_idx(sizet, gidx)
(level, p, localelemno) = cart_ind(sizet, gidx).I
elem = localelems[localelemno]
(ip, jp) = perimeter[p]

weight = pweight[Topologies._get_idx(sizet_wt, (ip, jp, 1, elem))]
weight = pweight[linear_ind(sizet_wt, (ip, jp, 1, elem))]
for fidx in scalarfidx
data_idx =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
pperimeter_data[level, p, fidx, elem] = pdata[data_idx] * weight
end
for fidx in covariant12fidx
data_idx1 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 = linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
Topologies._get_idx_metric(sizet_metric, (level, ip, jp, elem))
pperimeter_data[level, p, fidx, elem] =
Expand All @@ -276,10 +273,8 @@ function dss_transform_kernel!(
) * weight
end
for fidx in contravariant12fidx
data_idx1 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 = linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
Topologies._get_idx_metric(sizet_metric, (level, ip, jp, elem))
pperimeter_data[level, p, fidx, elem] =
Expand Down Expand Up @@ -363,19 +358,16 @@ function dss_untransform_kernel!(
sizet_wt = (Nq, Nq, 1, nelems)
sizet_metric = (nlevels, Nq, Nq, nmetric, nelems)

(level, p, localelemno) = Topologies._get_idx(sizet, gidx)
(level, p, localelemno) = cart_ind(sizet, gidx).I
elem = localelems[localelemno]
ip, jp = perimeter[p]
for fidx in scalarfidx
data_idx =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
pdata[data_idx] = pperimeter_data[level, p, fidx, elem]
end
for fidx in covariant12fidx
data_idx1 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 = linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
Topologies._get_idx_metric(sizet_metric, (level, ip, jp, elem))
pdata[data_idx1] =
Expand All @@ -386,10 +378,8 @@ function dss_untransform_kernel!(
p∂x∂ξ[idx22] * pperimeter_data[level, p, fidx + 1, elem]
end
for fidx in contravariant12fidx
data_idx1 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
Topologies._get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 = linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
Topologies._get_idx_metric(sizet_metric, (level, ip, jp, elem))
pdata[data_idx1] =
Expand Down Expand Up @@ -445,7 +435,7 @@ function dss_local_ghost_kernel!(
nghostvertices = length(ghost_vertex_offset) - 1
if gidx nlevels * nfidx * nghostvertices
sizev = (nlevels, nfidx, nghostvertices)
(level, fidx, vertexid) = Topologies._get_idx(sizev, gidx)
(level, fidx, vertexid) = cart_ind(sizev, gidx).I
sum_data = FT(0)
st, en =
ghost_vertex_offset[vertexid], ghost_vertex_offset[vertexid + 1]
Expand Down Expand Up @@ -506,8 +496,8 @@ function fill_send_buffer_kernel!(
sizet = (nlevels, nfid, nsend)
#if gidx ≤ nsend * nlevels * nfid
if gidx nlevels * nfid * nsend
#(isend, level, fidx) = Topologies._get_idx(sizet, gidx)
(level, fidx, isend) = Topologies._get_idx(sizet, gidx)
#(isend, level, fidx) = cart_ind(sizet, gidx).I
(level, fidx, isend) = cart_ind(sizet, gidx).I
lidx = send_buf_idx[isend, 1]
ip = send_buf_idx[isend, 2]
idx = level + ((fidx - 1) + (isend - 1) * nfid) * nlevels
Expand Down Expand Up @@ -551,8 +541,8 @@ function load_from_recv_buffer_kernel!(
sizet = (nlevels, nfid, nrecv)
#if gidx ≤ nrecv * nlevels * nfid
if gidx nlevels * nfid * nrecv
#(irecv, level, fidx) = Topologies._get_idx(sizet, gidx)
(level, fidx, irecv) = Topologies._get_idx(sizet, gidx)
#(irecv, level, fidx) = cart_ind(sizet, gidx).I
(level, fidx, irecv) = cart_ind(sizet, gidx).I
lidx = recv_buf_idx[irecv, 1]
ip = recv_buf_idx[irecv, 2]
idx = level + ((fidx - 1) + (irecv - 1) * nfid) * nlevels
Expand Down Expand Up @@ -605,7 +595,7 @@ function dss_ghost_kernel!(

if gidx nlevels * nfidx * nghostvertices
(level, fidx, ghostvertexidx) =
Topologies._get_idx((nlevels, nfidx, nghostvertices), gidx)
cart_ind((nlevels, nfidx, nghostvertices), gidx).I
idxresult, lvertresult = repr_ghost_vertex[ghostvertexidx]
ipresult = perimeter_vertex_node_index(lvertresult)
result = pperimeter_data[level, ipresult, fidx, idxresult]
Expand Down
2 changes: 1 addition & 1 deletion src/Topologies/Topologies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using DocStringExtensions
import ClimaComms, Adapt

import ..ClimaCore
import ..Utilities: Cache
import ..Utilities: Cache, cart_ind, linear_ind
import ..Geometry
import ..Domains: Domains, coordinate_type
import ..Meshes: Meshes, domain, coordinates
Expand Down
27 changes: 14 additions & 13 deletions src/Topologies/dss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,17 +367,17 @@ function dss_transform!(

@inbounds for elem in localelems
for (p, (ip, jp)) in enumerate(perimeter)
pw = pweight[_get_idx(sizet_wt, (ip, jp, 1, elem))]
pw = pweight[linear_ind(sizet_wt, (ip, jp, 1, elem))]

for fidx in scalarfidx, level in 1:nlevels
data_idx = _get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
pperimeter_data[level, p, fidx, elem] = pdata[data_idx] * pw
end

for fidx in covariant12fidx, level in 1:nlevels
data_idx1 = _get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
_get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
_get_idx_metric(sizet_metric, (level, ip, jp, elem))
pperimeter_data[level, p, fidx, elem] =
Expand All @@ -393,9 +393,9 @@ function dss_transform!(
end

for fidx in contravariant12fidx, level in 1:nlevels
data_idx1 = _get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx1 = linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
_get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
_get_idx_metric(sizet_metric, (level, ip, jp, elem))
pperimeter_data[level, p, fidx, elem] =
Expand Down Expand Up @@ -467,16 +467,17 @@ function dss_untransform!(
for (p, (ip, jp)) in enumerate(perimeter)
for fidx in scalarfidx
for level in 1:nlevels
data_idx = _get_idx(sizet_data, (level, ip, jp, fidx, elem))
data_idx =
linear_ind(sizet_data, (level, ip, jp, fidx, elem))
pdata[data_idx] = pperimeter_data[level, p, fidx, elem]
end
end
for fidx in covariant12fidx
for level in 1:nlevels
data_idx1 =
_get_idx(sizet_data, (level, ip, jp, fidx, elem))
linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
_get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
_get_idx_metric(sizet_metric, (level, ip, jp, elem))
pdata[data_idx1] =
Expand All @@ -490,9 +491,9 @@ function dss_untransform!(
for fidx in contravariant12fidx
for level in 1:nlevels
data_idx1 =
_get_idx(sizet_data, (level, ip, jp, fidx, elem))
linear_ind(sizet_data, (level, ip, jp, fidx, elem))
data_idx2 =
_get_idx(sizet_data, (level, ip, jp, fidx + 1, elem))
linear_ind(sizet_data, (level, ip, jp, fidx + 1, elem))
(idx11, idx12, idx21, idx22) =
_get_idx_metric(sizet_metric, (level, ip, jp, elem))
pdata[data_idx1] =
Expand Down Expand Up @@ -520,7 +521,7 @@ function dss_load_perimeter_data!(
sizet = (nlevels, Nq, Nq, nfid, nelems)
for elem in 1:nelems, (p, (ip, jp)) in enumerate(perimeter)
for fidx in 1:nfid, level in 1:nlevels
idx = _get_idx(sizet, (level, ip, jp, fidx, elem))
idx = linear_ind(sizet, (level, ip, jp, fidx, elem))
pperimeter_data[level, p, fidx, elem] = pdata[idx]
end
end
Expand All @@ -539,7 +540,7 @@ function dss_unload_perimeter_data!(
sizet = (nlevels, Nq, Nq, nfid, nelems)
for elem in 1:nelems, (p, (ip, jp)) in enumerate(perimeter)
for fidx in 1:nfid, level in 1:nlevels
idx = _get_idx(sizet, (level, ip, jp, fidx, elem))
idx = linear_ind(sizet, (level, ip, jp, fidx, elem))
pdata[idx] = pperimeter_data[level, p, fidx, elem]
end
end
Expand Down
51 changes: 12 additions & 39 deletions src/Topologies/dss_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,47 +231,20 @@ end
end

# helper functions for DSS2
function _get_idx(sizet::NTuple{5, Int}, loc::NTuple{5, Int})
(n1, n2, n3, n4, n5) = sizet
(i1, i2, i3, i4, i5) = loc
return i1 +
((i2 - 1) + ((i3 - 1) + ((i4 - 1) + (i5 - 1) * n4) * n3) * n2) * n1
end

function _get_idx(sizet::NTuple{4, Int}, loc::NTuple{4, Int})
(n1, n2, n3, n4) = sizet
(i1, i2, i3, i4) = loc
return i1 + ((i2 - 1) + ((i3 - 1) + (i4 - 1) * n3) * n2) * n1
end

function _get_idx(sizet::NTuple{3, Int}, idx::Int)
(n1, n2, n3) = sizet
i3 = cld(idx, n1 * n2)
i2 = cld(idx - (i3 - 1) * n1 * n2, n1)
i1 = idx - (i3 - 1) * n1 * n2 - (i2 - 1) * n1
return (i1, i2, i3)
end

function _get_idx(sizet::NTuple{4, Int}, idx::Int)
(n1, n2, n3, n4) = sizet
i4 = cld(idx, n1 * n2 * n3)
i3 = cld(idx - (i4 - 1) * n1 * n2 * n3, n1 * n2)
i2 = cld(idx - (i4 - 1) * n1 * n2 * n3 - (i3 - 1) * n1 * n2, n1)
i1 = idx - (i4 - 1) * n1 * n2 * n3 - (i3 - 1) * n1 * n2 - (i2 - 1) * n1
return (i1, i2, i3, i4)
end

function _get_idx_metric(sizet::NTuple{5, Int}, loc::NTuple{4, Int})
nmetric = sizet[4]
(i11, i12, i21, i22) = nmetric == 4 ? (1, 2, 3, 4) : (1, 2, 4, 5)
(level, i, j, elem) = loc
return (
_get_idx(sizet, (level, i, j, i11, elem)),
_get_idx(sizet, (level, i, j, i12, elem)),
_get_idx(sizet, (level, i, j, i21, elem)),
_get_idx(sizet, (level, i, j, i22, elem)),
)
return nothing
@inbounds begin
nmetric = sizet[4]
(i11, i12, i21, i22) = nmetric == 4 ? (1, 2, 3, 4) : (1, 2, 4, 5)
(level, i, j, elem) = loc
inds = (
linear_ind(sizet, (level, i, j, i11, elem)),
linear_ind(sizet, (level, i, j, i12, elem)),
linear_ind(sizet, (level, i, j, i21, elem)),
linear_ind(sizet, (level, i, j, i22, elem)),
)
return inds
end
end

function _representative_slab(
Expand Down
27 changes: 27 additions & 0 deletions src/Utilities/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,31 @@ include("plushalf.jl")
include("unrolled_functions.jl")
include("cache.jl")

"""
cart_ind(n::NTuple, i::Integer)
Returns a `CartesianIndex` from the list
`CartesianIndices(map(x->Base.OneTo(x), n))[i]`
given size `n` and location `i`.
"""
Base.@propagate_inbounds cart_ind(n::NTuple, i::Integer) =
@inbounds CartesianIndices(map(x -> Base.OneTo(x), n))[i]

"""
linear_ind(n::NTuple, ci::CartesianIndex)
linear_ind(n::NTuple, t::NTuple)
Returns a linear index from the list
`LinearIndices(map(x->Base.OneTo(x), n))[ci]`
given size `n` and cartesian index `ci`.
The `linear_ind(n::NTuple, t::NTuple)` wraps `t`
in a `Cartesian` index and calls
`linear_ind(n::NTuple, ci::CartesianIndex)`.
"""
Base.@propagate_inbounds linear_ind(n::NTuple, ci::CartesianIndex) =
@inbounds LinearIndices(map(x -> Base.OneTo(x), n))[ci]
Base.@propagate_inbounds linear_ind(n::NTuple, loc::NTuple) =
linear_ind(n, CartesianIndex(loc))

end # module
Loading

0 comments on commit 5f23521

Please sign in to comment.