Skip to content

Commit

Permalink
Add column_accumulate and generalize column_reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Aug 1, 2024
1 parent cf881de commit 9b0a31d
Show file tree
Hide file tree
Showing 12 changed files with 471 additions and 598 deletions.
4 changes: 2 additions & 2 deletions .buildkite/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,9 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"
weakdeps = ["CUDA", "MPI"]

[deps.ClimaComms.extensions]
Expand Down
19 changes: 4 additions & 15 deletions benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
version = "1.18.0+2"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand All @@ -201,10 +201,10 @@ version = "0.6.3"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.14.9"
version = "0.14.10"

[deps.ClimaCore.extensions]
ClimaCoreCUDAExt = "CUDA"
Expand Down Expand Up @@ -586,11 +586,6 @@ git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.10.0+0"

[[deps.IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"

[[deps.InlineStrings]]
deps = ["Parsers"]
git-tree-sha1 = "86356004f30f8e737eff143d57d41bd580e437aa"
Expand Down Expand Up @@ -1282,12 +1277,6 @@ version = "2.4.0"
[deps.SpecialFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.8.10"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ Extrapolate
```@docs
column_integral_definite!
column_integral_indefinite!
column_mapreduce!
column_reduce!
column_accumulate!
```

## Internal APIs
Expand Down
181 changes: 69 additions & 112 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
@@ -1,126 +1,83 @@
import ClimaCore: Spaces, Fields, Spaces, Topologies
import ClimaCore.Operators: strip_space
import ClimaCore: Spaces, Fields, level, column
import ClimaCore.Operators:
column_integral_definite!,
column_integral_definite_kernel!,
column_integral_indefinite_kernel!,
column_integral_indefinite!,
column_mapreduce_device!,
_column_integral_definite!,
_column_integral_indefinite!

left_idx,
strip_space,
column_reduce_device!,
single_column_reduce!,
column_accumulate_device!,
single_column_accumulate!
import ClimaComms
using CUDA: @cuda

function column_integral_definite!(
function column_reduce_device!(
::ClimaComms.CUDADevice,
∫field::Fields.Field,
ᶜfield::Fields.Field,
)
space = axes(∫field)
Ni, Nj, _, _, Nh = size(Fields.field_values(∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (strip_space(∫field, space), strip_space(ᶜfield, space))
auto_launch!(
column_integral_definite_kernel!,
args,
size(Fields.field_values(∫field));
threads_s = nthreads,
blocks_s = nblocks,
)
end

function column_integral_definite_kernel!(
∫field,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
∫field_column = Spaces.column(∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_definite!(∫field_column, ᶜfield_column)
end
return nothing
end

function column_integral_indefinite_kernel!(
ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
ᶠ∫field_column = Spaces.column(ᶠ∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_indefinite!(ᶠ∫field_column, ᶜfield_column)
end
return nothing
end

function column_integral_indefinite!(
::ClimaComms.CUDADevice,
ᶠ∫field::Fields.Field,
ᶜfield::Fields.Field,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶠ∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (ᶠ∫field, ᶜfield)
auto_launch!(
column_integral_indefinite_kernel!,
args,
size(Fields.field_values(ᶠ∫field));
threads_s = nthreads,
blocks_s = nblocks,
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
single_column_reduce!,
f,
transform,
strip_space(output, axes(output)), # The output space is irrelevant here
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_device!(
function column_accumulate_device!(
::ClimaComms.CUDADevice,
fn::F,
op::O,
reduced_field::Fields.Field,
fields::Fields.Field...,
) where {F, O}
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
kernel! = if first(fields) isa Fields.ExtrudedFiniteDifferenceField
column_mapreduce_kernel_extruded!
else
column_mapreduce_kernel!
end
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
fn,
op,
# reduced_field,
strip_space(reduced_field, axes(reduced_field)),
# fields...,
map(field -> strip_space(field, axes(field)), fields)...,
)
auto_launch!(
kernel!,
args,
size(Fields.field_values(reduced_field));
threads_s = nthreads,
blocks_s = nblocks,
single_column_accumulate!,
f,
transform,
strip_space(output, space),
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_kernel_extruded!(
fn::F,
op::O,
reduced_field,
fields...,
) where {F, O}
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
reduced_field_column = Spaces.column(reduced_field, i, j, h)
field_columns = map(field -> Spaces.column(field, i, j, h), fields)
_column_mapreduce!(fn, op, reduced_field_column, field_columns...)
bycolumn_kernel!(
single_column_function!::S,
f::F,
transform::T,
output,
input,
init,
space,
) where {S, F, T} =
if space isa Spaces.FiniteDifferenceSpace
single_column_function!(f, transform, output, input, init, space)
else
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
single_column_function!(
f,
transform,
column(output, i, j, h),
column(input, i, j, h),
init,
column(space, i, j, h),
)
end
end
return nothing
end
19 changes: 4 additions & 15 deletions lib/ClimaCoreMakie/examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand All @@ -218,10 +218,10 @@ version = "0.6.3"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"]
path = "../../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.14.9"
version = "0.14.10"

[deps.ClimaCore.extensions]
ClimaCoreCUDAExt = "CUDA"
Expand Down Expand Up @@ -754,11 +754,6 @@ git-tree-sha1 = "47ac8cc196b81001a711f4b2c12c97372338f00c"
uuid = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
version = "1.24.2"

[[deps.IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"

[[deps.ImageAxes]]
deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"]
git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8"
Expand Down Expand Up @@ -1633,12 +1628,6 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c"
uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
version = "0.1.1"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.8.10"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132"
Expand Down
5 changes: 5 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ Base.similar(fv::FieldVector{T}, ::Type{T′}) where {T, T′} =
Base.copy(fv::FieldVector{T}) where {T} = FieldVector{T}(map(copy, _values(fv)))
Base.zero(fv::FieldVector{T}) where {T} = FieldVector{T}(map(zero, _values(fv)))

Base.@propagate_inbounds slab(fv::FieldVector{T}, inds...) where {T} =
FieldVector{T}(slab_args(_values(fv), inds...))
Base.@propagate_inbounds column(fv::FieldVector{T}, inds...) where {T} =
FieldVector{T}(column_args(_values(fv), inds...))

struct FieldVectorStyle <: Base.Broadcast.AbstractArrayStyle{1} end

Base.Broadcast.BroadcastStyle(::Type{<:FieldVector}) = FieldVectorStyle()
Expand Down
38 changes: 5 additions & 33 deletions src/Fields/indices.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@

Base.@propagate_inbounds Base.getindex(field::Field, colidx::ColumnIndex) =
column(field, colidx)
Base.@propagate_inbounds function Base.getindex(
fv::FieldVector{T},
const ColumnIndexable =
Union{Field, FieldVector, Base.AbstractBroadcasted, Spaces.AbstractSpace}
Base.@propagate_inbounds Base.getindex(
x::ColumnIndexable,
colidx::ColumnIndex,
) where {T}
values = map(x -> x[colidx], _values(fv))
return FieldVector{T, typeof(values)}(values)
end
Base.@propagate_inbounds function column(
field::SpectralElementField1D,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::SpectralElementField2D,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
) = column(x, colidx.ij..., colidx.h)

"""
Fields.bycolumn(fn, space)
Expand Down
5 changes: 5 additions & 0 deletions src/Operators/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ end
strip_space_args(::Tuple{}, space) = ()
strip_space_args(args::Tuple, space) =
(strip_space(args[1], space), strip_space_args(Base.tail(args), space)...)

function unstrip_space(field::Field, parent_space)
new_space = reconstruct_placeholder_space(axes(field), parent_space)
return Field(Fields.field_values(field), new_space)
end
9 changes: 9 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,15 @@ function Base.Broadcast.materialize!(
)
end

Base.@propagate_inbounds column(op::FiniteDifferenceOperator, inds...) =
unionall_type(typeof(op))(column_args(op.bcs, inds...))
Base.@propagate_inbounds column(sbc::StencilBroadcasted{S}, inds...) where {S} =
StencilBroadcasted{S}(
column(sbc.op, inds...),
column_args(sbc.args, inds...),
column(sbc.axes, inds...),
)

#TODO: the optimizer dies with column broadcast expressions over a certain complexity
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
Expand Down
Loading

0 comments on commit 9b0a31d

Please sign in to comment.