Skip to content

Commit

Permalink
Add column_accumulate and generalize column_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Jul 31, 2024
1 parent e2d61b0 commit a574d46
Show file tree
Hide file tree
Showing 11 changed files with 473 additions and 566 deletions.
3 changes: 2 additions & 1 deletion docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ Extrapolate
```@docs
column_integral_definite!
column_integral_indefinite!
column_mapreduce!
column_reduce!
column_accumulate!
```

## Internal APIs
Expand Down
4 changes: 4 additions & 0 deletions ext/cuda/cuda_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import CUDA
import ClimaCore.Fields
import ClimaCore.DataLayouts
import ClimaCore.DataLayouts: empty_kernel_stats
import ClimaCore.Utilities: _assert

# Number of items in a field: defers to the method for the field's underlying
# values (obtained with `Fields.field_values`).
function get_n_items(field::Fields.Field)
    data = Fields.field_values(field)
    return get_n_items(data)
end
# Number of items in a data layout: defers to the method that accepts the
# result of `size(data)`.
function get_n_items(data::DataLayouts.AbstractData)
    dims = size(data)
    return get_n_items(dims)
end
Expand All @@ -13,6 +14,9 @@ const reported_stats = Dict()
# CUDA-device method: removes every entry from `reported_stats` and returns
# the emptied container (the result of `empty!`).
function empty_kernel_stats(::ClimaComms.CUDADevice)
    return empty!(reported_stats)
end
"""
    collect_kernel_stats()

Always return `false`.

NOTE(review): presumably a switch that callers consult before gathering
kernel statistics -- confirm against the call sites.
"""
function collect_kernel_stats()
    return false
end

# CUDA-device method of `_assert`. Both `cond` and `text` are callables: they
# are invoked here and the results are handed to `CUDA.@cuassert`.
# The argument names are kept as-is because the macro receives the
# expressions `cond()` and `text()` themselves.
function _assert(
    cond::Cond,
    text::Text,
    ::ClimaComms.CUDADevice,
) where {Cond, Text}
    CUDA.@cuassert cond() text()
end

"""
auto_launch!(f!::F!, args,
::Union{
Expand Down
181 changes: 69 additions & 112 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
@@ -1,126 +1,83 @@
import ClimaCore: Spaces, Fields, Spaces, Topologies
import ClimaCore.Operators: strip_space
import ClimaCore: Spaces, Fields, level, column
import ClimaCore.Operators:
column_integral_definite!,
column_integral_definite_kernel!,
column_integral_indefinite_kernel!,
column_integral_indefinite!,
column_mapreduce_device!,
_column_integral_definite!,
_column_integral_indefinite!

left_idx,
strip_space,
column_reduce_device!,
single_column_reduce!,
column_accumulate_device!,
single_column_accumulate!
import ClimaComms
using CUDA: @cuda

function column_integral_definite!(
function column_reduce_device!(
::ClimaComms.CUDADevice,
∫field::Fields.Field,
ᶜfield::Fields.Field,
)
space = axes(∫field)
Ni, Nj, _, _, Nh = size(Fields.field_values(∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (strip_space(∫field, space), strip_space(ᶜfield, space))
auto_launch!(
column_integral_definite_kernel!,
args,
size(Fields.field_values(∫field));
threads_s = nthreads,
blocks_s = nblocks,
)
end

# Kernel body for the definite column integral.
#
# Each thread derives a linear index from `threadIdx`, `blockIdx` and
# `blockDim`. Threads whose index exceeds `Ni * Nj * Nh` (taken from the size
# of `ᶜfield`'s values) do nothing. Every other thread converts its index to
# `(i, j, h)` with `cart_ind` and calls `_column_integral_definite!` on the
# matching columns of `∫field` and `ᶜfield`. Always returns `nothing`.
function column_integral_definite_kernel!(
    ∫field,
    ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
    Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
    n_columns = Ni * Nj * Nh
    thread_id = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    # Surplus threads (index past the last column) have no work to do.
    thread_id > n_columns && return nothing

    i, j, h = cart_ind((Ni, Nj, Nh), thread_id).I
    ∫column = Spaces.column(∫field, i, j, h)
    ᶜcolumn = Spaces.column(ᶜfield, i, j, h)
    _column_integral_definite!(∫column, ᶜcolumn)
    return nothing
end

# Kernel body for the indefinite column integral.
#
# Each thread derives a linear index from `threadIdx`, `blockIdx` and
# `blockDim`. Threads whose index exceeds `Ni * Nj * Nh` (taken from the size
# of `ᶜfield`'s values) do nothing. Every other thread converts its index to
# `(i, j, h)` with `cart_ind` and calls `_column_integral_indefinite!` on the
# matching columns of `ᶠ∫field` and `ᶜfield`. Always returns `nothing`.
function column_integral_indefinite_kernel!(
    ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
    ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
    Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
    n_columns = Ni * Nj * Nh
    thread_id = threadIdx().x + (blockIdx().x - 1) * blockDim().x

    # Surplus threads (index past the last column) have no work to do.
    thread_id > n_columns && return nothing

    i, j, h = cart_ind((Ni, Nj, Nh), thread_id).I
    ᶠ∫column = Spaces.column(ᶠ∫field, i, j, h)
    ᶜcolumn = Spaces.column(ᶜfield, i, j, h)
    _column_integral_indefinite!(ᶠ∫column, ᶜcolumn)
    return nothing
end

function column_integral_indefinite!(
::ClimaComms.CUDADevice,
ᶠ∫field::Fields.Field,
ᶜfield::Fields.Field,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶠ∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (ᶠ∫field, ᶜfield)
auto_launch!(
column_integral_indefinite_kernel!,
args,
size(Fields.field_values(ᶠ∫field));
threads_s = nthreads,
blocks_s = nblocks,
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
single_column_reduce!,
f,
transform,
strip_space(output, axes(output)), # The output space is irrelevant here
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_device!(
function column_accumulate_device!(
::ClimaComms.CUDADevice,
fn::F,
op::O,
reduced_field::Fields.Field,
fields::Fields.Field...,
) where {F, O}
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
kernel! = if first(fields) isa Fields.ExtrudedFiniteDifferenceField
column_mapreduce_kernel_extruded!
else
column_mapreduce_kernel!
end
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
fn,
op,
# reduced_field,
strip_space(reduced_field, axes(reduced_field)),
# fields...,
map(field -> strip_space(field, axes(field)), fields)...,
)
auto_launch!(
kernel!,
args,
size(Fields.field_values(reduced_field));
threads_s = nthreads,
blocks_s = nblocks,
single_column_accumulate!,
f,
transform,
strip_space(output, space),
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_kernel_extruded!(
fn::F,
op::O,
reduced_field,
fields...,
) where {F, O}
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
reduced_field_column = Spaces.column(reduced_field, i, j, h)
field_columns = map(field -> Spaces.column(field, i, j, h), fields)
_column_mapreduce!(fn, op, reduced_field_column, field_columns...)
bycolumn_kernel!(
single_column_function!::S,
f::F,
transform::T,
output,
input,
init,
space,
) where {S, F, T} =
if space isa Spaces.FiniteDifferenceSpace
single_column_function!(f, transform, output, input, init, space)
else
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
single_column_function!(
f,
transform,
column(output, i, j, h),
column(input, i, j, h),
init,
column(space, i, j, h),
)
end
end
return nothing
end
5 changes: 5 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ Base.similar(fv::FieldVector{T}, ::Type{T′}) where {T, T′} =
# Build a new `FieldVector{T}` whose stored values are `copy`s of the
# values held by `fv`.
function Base.copy(fv::FieldVector{T}) where {T}
    copied_values = map(copy, _values(fv))
    return FieldVector{T}(copied_values)
end
# Build a new `FieldVector{T}` by applying `zero` to each of the values
# held by `fv`.
function Base.zero(fv::FieldVector{T}) where {T}
    zeroed_values = map(zero, _values(fv))
    return FieldVector{T}(zeroed_values)
end

# Slab of a `FieldVector`: passes the stored values and `inds...` to
# `slab_args` and rewraps the result in a `FieldVector{T}`.
Base.@propagate_inbounds function slab(fv::FieldVector{T}, inds...) where {T}
    slab_values = slab_args(_values(fv), inds...)
    return FieldVector{T}(slab_values)
end
# Column of a `FieldVector`: passes the stored values and `inds...` to
# `column_args` and rewraps the result in a `FieldVector{T}`.
Base.@propagate_inbounds function column(fv::FieldVector{T}, inds...) where {T}
    column_values = column_args(_values(fv), inds...)
    return FieldVector{T}(column_values)
end

struct FieldVectorStyle <: Base.Broadcast.AbstractArrayStyle{1} end

Base.Broadcast.BroadcastStyle(::Type{<:FieldVector}) = FieldVectorStyle()
Expand Down
38 changes: 5 additions & 33 deletions src/Fields/indices.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@

Base.@propagate_inbounds Base.getindex(field::Field, colidx::ColumnIndex) =
column(field, colidx)
Base.@propagate_inbounds function Base.getindex(
fv::FieldVector{T},
const ColumnIndexable =
Union{Field, FieldVector, Base.AbstractBroadcasted, Spaces.AbstractSpace}
Base.@propagate_inbounds Base.getindex(
x::ColumnIndexable,
colidx::ColumnIndex,
) where {T}
values = map(x -> x[colidx], _values(fv))
return FieldVector{T, typeof(values)}(values)
end
Base.@propagate_inbounds function column(
field::SpectralElementField1D,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::SpectralElementField2D,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
) = column(x, colidx.ij..., colidx.h)

"""
Fields.bycolumn(fn, space)
Expand Down
5 changes: 5 additions & 0 deletions src/Operators/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ end
# Apply `strip_space(arg, space)` to every element of a tuple, recursing on
# the tail so the result is again a tuple.

# Base case: nothing left to strip.
function strip_space_args(::Tuple{}, space)
    return ()
end

# Recursive case: strip the first element, then the remaining ones.
function strip_space_args(args::Tuple, space)
    stripped_head = strip_space(first(args), space)
    stripped_tail = strip_space_args(Base.tail(args), space)
    return (stripped_head, stripped_tail...)
end

# Return a `Field` that holds the same values as `field` but whose space is
# `reconstruct_placeholder_space(axes(field), parent_space)`.
# NOTE(review): presumably the inverse of `strip_space` -- confirm against
# the definition of `reconstruct_placeholder_space`.
function unstrip_space(field::Field, parent_space)
    placeholder_space = axes(field)
    full_space = reconstruct_placeholder_space(placeholder_space, parent_space)
    data = Fields.field_values(field)
    return Field(data, full_space)
end
9 changes: 9 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,15 @@ function Base.Broadcast.materialize!(
)
end

# Column of a finite-difference operator: rebuilds an operator of the same
# `unionall_type` from the result of `column_args(op.bcs, inds...)`.
Base.@propagate_inbounds function column(op::FiniteDifferenceOperator, inds...)
    op_constructor = unionall_type(typeof(op))
    column_bcs = column_args(op.bcs, inds...)
    return op_constructor(column_bcs)
end
# Column of a `StencilBroadcasted`: takes the column of its operator, of each
# argument (via `column_args`) and of its axes, then reassembles a
# `StencilBroadcasted{S}` from those three pieces.
Base.@propagate_inbounds function column(
    sbc::StencilBroadcasted{S},
    inds...,
) where {S}
    column_op = column(sbc.op, inds...)
    column_arguments = column_args(sbc.args, inds...)
    column_axes = column(sbc.axes, inds...)
    return StencilBroadcasted{S}(column_op, column_arguments, column_axes)
end

#TODO: the optimizer dies with column broadcast expressions over a certain complexity
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
Expand Down
Loading

0 comments on commit a574d46

Please sign in to comment.