Skip to content

Commit

Permalink
Remove multiple integer indexing in DataLayouts
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Aug 9, 2024
1 parent f8b7ad4 commit 51f60b8
Show file tree
Hide file tree
Showing 31 changed files with 534 additions and 514 deletions.
19 changes: 10 additions & 9 deletions ext/cuda/limiters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ClimaCore.Limiters:
apply_limiter!
import ClimaCore.Fields
import ClimaCore: DataLayouts, Spaces, Topologies, Fields
import ClimaCore.DataLayouts: slab_index
using CUDA

function config_threadblock(Nv, Nh)
Expand Down Expand Up @@ -63,7 +64,7 @@ function compute_element_bounds_kernel!(
slab_ρ = slab(ρ, v, h)
for j in 1:Nj
for i in 1:Ni
q = rdiv(slab_ρq[i, j], slab_ρ[i, j])
q = rdiv(slab_ρq[slab_index(i, j)], slab_ρ[slab_index(i, j)])
if i == 1 && j == 1
q_min = q
q_max = q
Expand All @@ -74,8 +75,8 @@ function compute_element_bounds_kernel!(
end
end
slab_q_bounds = slab(q_bounds, v, h)
slab_q_bounds[1] = q_min
slab_q_bounds[2] = q_max
slab_q_bounds[slab_index(1)] = q_min
slab_q_bounds[slab_index(2)] = q_max
end
return nothing
end
Expand Down Expand Up @@ -123,18 +124,18 @@ function compute_neighbor_bounds_local_kernel!(
(v, h) = kernel_indexes(tidx, n).I
(; q_bounds, q_bounds_nbr, ghost_buffer, rtol) = limiter
slab_q_bounds = slab(q_bounds, v, h)
q_min = slab_q_bounds[1]
q_max = slab_q_bounds[2]
q_min = slab_q_bounds[slab_index(1)]
q_max = slab_q_bounds[slab_index(2)]
for lne in
local_neighbor_elem_offset[h]:(local_neighbor_elem_offset[h + 1] - 1)
h_nbr = local_neighbor_elem[lne]
slab_q_bounds = slab(q_bounds, v, h_nbr)
q_min = rmin(q_min, slab_q_bounds[1])
q_max = rmax(q_max, slab_q_bounds[2])
q_min = rmin(q_min, slab_q_bounds[slab_index(1)])
q_max = rmax(q_max, slab_q_bounds[slab_index(2)])
end
slab_q_bounds_nbr = slab(q_bounds_nbr, v, h)
slab_q_bounds_nbr[1] = q_min
slab_q_bounds_nbr[2] = q_max
slab_q_bounds_nbr[slab_index(1)] = q_min
slab_q_bounds_nbr[slab_index(2)] = q_max
end
return nothing
end
Expand Down
35 changes: 19 additions & 16 deletions ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ClimaCore.Fields
import ClimaCore.Spaces
import ClimaCore.Topologies
import ClimaCore.MatrixFields
import ClimaCore.DataLayouts: vindex
import ClimaCore.MatrixFields: single_field_solve!
import ClimaCore.MatrixFields: _single_field_solve!
import ClimaCore.MatrixFields: band_matrix_solve!, unzip_tuple_field_values
Expand Down Expand Up @@ -71,7 +72,7 @@ function _single_field_solve!(
b_data = Fields.field_values(b)
Nv = DataLayouts.nlevels(x_data)
@inbounds for v in 1:Nv
x_data[v] = inv(A.λ) b_data[v]
x_data[vindex(v)] = inv(A.λ) b_data[vindex(v)]
end
end

Expand All @@ -98,6 +99,7 @@ function band_matrix_solve_local_mem!(
Nv = DataLayouts.nlevels(x)
Ux, U₊₁ = cache
A₋₁, A₀, A₊₁ = Aⱼs
vi = vindex

Ux_local = MArray{Tuple{Nv}, eltype(Ux)}(undef)
U₊₁_local = MArray{Tuple{Nv}, eltype(U₊₁)}(undef)
Expand All @@ -107,16 +109,16 @@ function band_matrix_solve_local_mem!(
A₊₁_local = MArray{Tuple{Nv}, eltype(A₊₁)}(undef)
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
@inbounds for v in 1:Nv
A₋₁_local[v] = A₋₁[v]
A₀_local[v] = A₀[v]
A₊₁_local[v] = A₊₁[v]
b_local[v] = b[v]
A₋₁_local[v] = A₋₁[vi(v)]
A₀_local[v] = A₀[vi(v)]
A₊₁_local[v] = A₊₁[vi(v)]
b_local[v] = b[vi(v)]
end
cache_local = (Ux_local, U₊₁_local)
Aⱼs_local = (A₋₁, A₀, A₊₁)
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local, vi)
@inbounds for v in 1:Nv
x[v] = x_local[v]
x[vi(v)] = x_local[v]
end
return nothing
end
Expand All @@ -128,6 +130,7 @@ function band_matrix_solve_local_mem!(
Aⱼs,
b,
)
vi = vindex
Nv = DataLayouts.nlevels(x)
Ux, U₊₁, U₊₂ = cache
A₋₂, A₋₁, A₀, A₊₁, A₊₂ = Aⱼs
Expand All @@ -142,18 +145,18 @@ function band_matrix_solve_local_mem!(
A₊₂_local = MArray{Tuple{Nv}, eltype(A₊₂)}(undef)
b_local = MArray{Tuple{Nv}, eltype(b)}(undef)
@inbounds for v in 1:Nv
A₋₂_local[v] = A₋₂[v]
A₋₁_local[v] = A₋₁[v]
A₀_local[v] = A₀[v]
A₊₁_local[v] = A₊₁[v]
A₊₂_local[v] = A₊₂[v]
b_local[v] = b[v]
A₋₂_local[v] = A₋₂[vi(v)]
A₋₁_local[v] = A₋₁[vi(v)]
A₀_local[v] = A₀[vi(v)]
A₊₁_local[v] = A₊₁[vi(v)]
A₊₂_local[v] = A₊₂[vi(v)]
b_local[v] = b[vi(v)]
end
cache_local = (Ux_local, U₊₁_local, U₊₂_local)
Aⱼs_local = (A₋₂, A₋₁, A₀, A₊₁, A₊₂)
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local)
band_matrix_solve!(t, cache_local, x_local, Aⱼs_local, b_local, vi)
@inbounds for v in 1:Nv
x[v] = x_local[v]
x[vi(v)] = x_local[v]
end
return nothing
end
Expand All @@ -168,7 +171,7 @@ function band_matrix_solve_local_mem!(
Nv = DataLayouts.nlevels(x)
(A₀,) = Aⱼs
@inbounds for v in 1:Nv
x[v] = inv(A₀[v]) b[v]
x[vindex(v)] = inv(A₀[vindex(v)]) b[vindex(v)]
end
return nothing
end
19 changes: 8 additions & 11 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function set_interpolated_values_kernel!(
totalThreadsZ = gridDim().z * blockDim().z

_, Nq = size(I1)

CI = CartesianIndex
for i in hindex:totalThreadsX:num_horiz
h = local_horiz_indices[i]
for j in vindex:totalThreadsY:num_vert
Expand All @@ -73,8 +73,8 @@ function set_interpolated_values_kernel!(
I1[i, t] *
I2[i, s] *
(
A * field_values[k][t, s, nothing, v_lo, h] +
B * field_values[k][t, s, nothing, v_hi, h]
A * field_values[k][CI(t, s, 1, v_lo, h)] +
B * field_values[k][CI(t, s, 1, v_hi, h)]
)
end
end
Expand Down Expand Up @@ -107,7 +107,7 @@ function set_interpolated_values_kernel!(
totalThreadsZ = gridDim().z * blockDim().z

_, Nq = size(I)

CI = CartesianIndex
for i in hindex:totalThreadsX:num_horiz
h = local_horiz_indices[i]
for j in vindex:totalThreadsY:num_vert
Expand All @@ -121,10 +121,8 @@ function set_interpolated_values_kernel!(
I[i, t] *
I[i, s] *
(
A *
field_values[k][t, nothing, nothing, v_lo, h] +
B *
field_values[k][t, nothing, nothing, v_hi, h]
A * field_values[k][CI(t, 1, 1, v_lo, h)] +
B * field_values[k][CI(t, 1, 1, v_hi, h)]
)
end
end
Expand Down Expand Up @@ -199,7 +197,7 @@ function set_interpolated_values_kernel!(
out[i, k] +=
I1[i, t] *
I2[i, s] *
field_values[k][t, s, nothing, nothing, h]
field_values[k][CartesianIndex(t, s, 1, 1, h)]
end
end
end
Expand Down Expand Up @@ -232,8 +230,7 @@ function set_interpolated_values_kernel!(
out[i, k] = 0
for t in 1:Nq, s in 1:Nq
out[i, k] +=
I[i, i] *
field_values[k][t, nothing, nothing, nothing, h]
I[i, i] * field_values[k][CartesianIndex(t, 1, 1, 1, h)]
end
end
end
Expand Down
8 changes: 5 additions & 3 deletions lib/ClimaCorePlots/src/ClimaCorePlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import RecipesBase
import TriplotBase

import ClimaComms
import ClimaCore.DataLayouts: slab_index
import ClimaCore:
ClimaCore,
DataLayouts,
Expand Down Expand Up @@ -308,7 +309,7 @@ function _slice_along(field, coord)
hdata = ClimaCore.slab(hcoord_data, hidx)
hnode_idx = 1
for i in axes(hdata)[axis]
pt = axis == 1 ? hdata[i, 1] : hdata[1, i]
pt = axis == 1 ? hdata[slab_index(i, 1)] : hdata[slab_index(1, i)]
axis_value = Geometry.component(pt, axis)
coord_value = Geometry.component(coord, 1)
if axis_value > coord_value
Expand Down Expand Up @@ -353,8 +354,9 @@ function _slice_along(field, coord)
islab = ClimaCore.slab(ortho_data, v, i)
# copy the nodal data
for ni in 1:size(islab)[1]
islab[ni] =
axis == 1 ? ijslab[hnode_idx, ni] : ijslab[ni, hnode_idx]
islab[slab_index(ni)] =
axis == 1 ? ijslab[slab_index(hnode_idx, ni)] :
ijslab[slab_index(ni, hnode_idx)]
end
end
end
Expand Down
7 changes: 4 additions & 3 deletions lib/ClimaCoreTempestRemap/src/netcdf.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import CommonDataModel
import ClimaCore: slab, column
import ClimaCore.DataLayouts: slab_index

"""
def_time_coord(nc::NCDataset, length=Inf, eltype=Float64;
Expand Down Expand Up @@ -97,7 +98,7 @@ function def_space_coord(
coords = Spaces.coordinates_data(space)

for (col, ((i, j), e)) in enumerate(nodes)
coord = slab(coords, e)[i, j]
coord = slab(coords, e)[slab_index(i, j)]
X[col] = coord.x
Y[col] = coord.y
end
Expand Down Expand Up @@ -149,7 +150,7 @@ function def_space_coord(
coords = Spaces.coordinates_data(space)

for (col, ((i, j), e)) in enumerate(nodes)
coord = slab(coords, e)[i, j]
coord = slab(coords, e)[slab_index(i, j)]
lon[col] = coord.long
lat[col] = coord.lat
end
Expand Down Expand Up @@ -328,7 +329,7 @@ function Base.setindex!(
end
data = Fields.field_values(field)
for (col, ((i, j), e)) in enumerate(nodes)
var[col, extraidx...] = slab(data, e)[i, j]
var[col, extraidx...] = slab(data, e)[slab_index(i, j)]
end
return var
end
Expand Down
Loading

0 comments on commit 51f60b8

Please sign in to comment.