Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add column_accumulate and generalize column_reduce #1903

Merged
merged 1 commit into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .buildkite/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,9 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"
weakdeps = ["CUDA", "MPI"]

[deps.ClimaComms.extensions]
Expand Down
19 changes: 4 additions & 15 deletions benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,9 @@ uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a"
version = "1.18.0+2"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand All @@ -201,10 +201,10 @@ version = "0.6.3"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.14.9"
version = "0.14.10"

[deps.ClimaCore.extensions]
ClimaCoreCUDAExt = "CUDA"
Expand Down Expand Up @@ -586,11 +586,6 @@ git-tree-sha1 = "ca0f6bf568b4bfc807e7537f081c81e35ceca114"
uuid = "e33a78d0-f292-5ffc-b300-72abe9b543c8"
version = "2.10.0+0"

[[deps.IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"

[[deps.InlineStrings]]
deps = ["Parsers"]
git-tree-sha1 = "86356004f30f8e737eff143d57d41bd580e437aa"
Expand Down Expand Up @@ -1282,12 +1277,6 @@ version = "2.4.0"
[deps.SpecialFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.8.10"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132"
Expand Down
3 changes: 2 additions & 1 deletion docs/src/operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ Extrapolate
```@docs
column_integral_definite!
column_integral_indefinite!
column_mapreduce!
column_reduce!
column_accumulate!
```

## Internal APIs
Expand Down
181 changes: 69 additions & 112 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
@@ -1,126 +1,83 @@
import ClimaCore: Spaces, Fields, Spaces, Topologies
import ClimaCore.Operators: strip_space
import ClimaCore: Spaces, Fields, level, column
import ClimaCore.Operators:
column_integral_definite!,
column_integral_definite_kernel!,
column_integral_indefinite_kernel!,
column_integral_indefinite!,
column_mapreduce_device!,
_column_integral_definite!,
_column_integral_indefinite!

left_idx,
strip_space,
column_reduce_device!,
single_column_reduce!,
column_accumulate_device!,
single_column_accumulate!
import ClimaComms
using CUDA: @cuda

function column_integral_definite!(
function column_reduce_device!(
::ClimaComms.CUDADevice,
∫field::Fields.Field,
ᶜfield::Fields.Field,
)
space = axes(∫field)
Ni, Nj, _, _, Nh = size(Fields.field_values(∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (strip_space(∫field, space), strip_space(ᶜfield, space))
auto_launch!(
column_integral_definite_kernel!,
args,
size(Fields.field_values(∫field));
threads_s = nthreads,
blocks_s = nblocks,
)
end

function column_integral_definite_kernel!(
∫field,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
∫field_column = Spaces.column(∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_definite!(∫field_column, ᶜfield_column)
end
return nothing
end

function column_integral_indefinite_kernel!(
ᶠ∫field::Fields.FaceExtrudedFiniteDifferenceField,
ᶜfield::Fields.CenterExtrudedFiniteDifferenceField,
)
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶜfield))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
ᶠ∫field_column = Spaces.column(ᶠ∫field, i, j, h)
ᶜfield_column = Spaces.column(ᶜfield, i, j, h)
_column_integral_indefinite!(ᶠ∫field_column, ᶜfield_column)
end
return nothing
end

function column_integral_indefinite!(
::ClimaComms.CUDADevice,
ᶠ∫field::Fields.Field,
ᶜfield::Fields.Field,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(ᶠ∫field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
args = (ᶠ∫field, ᶜfield)
auto_launch!(
column_integral_indefinite_kernel!,
args,
size(Fields.field_values(ᶠ∫field));
threads_s = nthreads,
blocks_s = nblocks,
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
single_column_reduce!,
f,
transform,
strip_space(output, axes(output)), # The output space is irrelevant here
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_device!(
function column_accumulate_device!(
::ClimaComms.CUDADevice,
fn::F,
op::O,
reduced_field::Fields.Field,
fields::Fields.Field...,
) where {F, O}
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
nthreads, nblocks = _configure_threadblock(Ni * Nj * Nh)
kernel! = if first(fields) isa Fields.ExtrudedFiniteDifferenceField
column_mapreduce_kernel_extruded!
else
column_mapreduce_kernel!
end
f::F,
transform::T,
output,
input,
init,
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
threads_s, blocks_s = _configure_threadblock(Ni * Nj * Nh)
args = (
fn,
op,
# reduced_field,
strip_space(reduced_field, axes(reduced_field)),
# fields...,
map(field -> strip_space(field, axes(field)), fields)...,
)
auto_launch!(
kernel!,
args,
size(Fields.field_values(reduced_field));
threads_s = nthreads,
blocks_s = nblocks,
single_column_accumulate!,
f,
transform,
strip_space(output, space),
strip_space(input, space),
init,
space,
)
auto_launch!(bycolumn_kernel!, args, (); threads_s, blocks_s)
end

function column_mapreduce_kernel_extruded!(
fn::F,
op::O,
reduced_field,
fields...,
) where {F, O}
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(reduced_field))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
reduced_field_column = Spaces.column(reduced_field, i, j, h)
field_columns = map(field -> Spaces.column(field, i, j, h), fields)
_column_mapreduce!(fn, op, reduced_field_column, field_columns...)
bycolumn_kernel!(
single_column_function!::S,
f::F,
transform::T,
output,
input,
init,
space,
) where {S, F, T} =
if space isa Spaces.FiniteDifferenceSpace
single_column_function!(f, transform, output, input, init, space)
else
idx = threadIdx().x + (blockIdx().x - 1) * blockDim().x
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
if idx <= Ni * Nj * Nh
i, j, h = cart_ind((Ni, Nj, Nh), idx).I
single_column_function!(
f,
transform,
column(output, i, j, h),
column(input, i, j, h),
init,
column(space, i, j, h),
)
end
end
return nothing
end
19 changes: 4 additions & 15 deletions lib/ClimaCoreMakie/examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ weakdeps = ["SparseArrays"]
ChainRulesCoreSparseArraysExt = "SparseArrays"

[[deps.ClimaComms]]
git-tree-sha1 = "2ca8c9ca6131a7be8ca262e6db79bc7aa94ab597"
git-tree-sha1 = "ec303a4a66dc0a0ebe15a639a7e685afeaa0daef"
uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.6.3"
version = "0.6.4"

[deps.ClimaComms.extensions]
ClimaCommsCUDAExt = "CUDA"
Expand All @@ -218,10 +218,10 @@ version = "0.6.3"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "Unrolled"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "KrylovKit", "LinearAlgebra", "MultiBroadcastFusion", "NVTX", "PkgVersion", "RecursiveArrayTools", "RootSolvers", "SparseArrays", "StaticArrays", "Statistics", "Unrolled"]
path = "../../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.14.9"
version = "0.14.10"

[deps.ClimaCore.extensions]
ClimaCoreCUDAExt = "CUDA"
Expand Down Expand Up @@ -754,11 +754,6 @@ git-tree-sha1 = "47ac8cc196b81001a711f4b2c12c97372338f00c"
uuid = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
version = "1.24.2"

[[deps.IfElse]]
git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1"
uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
version = "0.1.1"

[[deps.ImageAxes]]
deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"]
git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8"
Expand Down Expand Up @@ -1633,12 +1628,6 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c"
uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15"
version = "0.1.1"

[[deps.Static]]
deps = ["IfElse"]
git-tree-sha1 = "d2fdac9ff3906e27f7a618d47b676941baa6c80c"
uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
version = "0.8.10"

[[deps.StaticArrays]]
deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"]
git-tree-sha1 = "6e00379a24597be4ae1ee6b2d882e15392040132"
Expand Down
5 changes: 5 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ Base.similar(fv::FieldVector{T}, ::Type{T′}) where {T, T′} =
Base.copy(fv::FieldVector{T}) where {T} = FieldVector{T}(map(copy, _values(fv)))
Base.zero(fv::FieldVector{T}) where {T} = FieldVector{T}(map(zero, _values(fv)))

Base.@propagate_inbounds slab(fv::FieldVector{T}, inds...) where {T} =
FieldVector{T}(slab_args(_values(fv), inds...))
Base.@propagate_inbounds column(fv::FieldVector{T}, inds...) where {T} =
FieldVector{T}(column_args(_values(fv), inds...))

struct FieldVectorStyle <: Base.Broadcast.AbstractArrayStyle{1} end

Base.Broadcast.BroadcastStyle(::Type{<:FieldVector}) = FieldVectorStyle()
Expand Down
38 changes: 5 additions & 33 deletions src/Fields/indices.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,9 @@

Base.@propagate_inbounds Base.getindex(field::Field, colidx::ColumnIndex) =
column(field, colidx)
Base.@propagate_inbounds function Base.getindex(
fv::FieldVector{T},
const ColumnIndexable =
Union{Field, FieldVector, Base.AbstractBroadcasted, Spaces.AbstractSpace}
Base.@propagate_inbounds Base.getindex(
x::ColumnIndexable,
colidx::ColumnIndex,
) where {T}
values = map(x -> x[colidx], _values(fv))
return FieldVector{T, typeof(values)}(values)
end
Base.@propagate_inbounds function column(
field::SpectralElementField1D,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
Base.@propagate_inbounds function column(
field::SpectralElementField2D,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
Base.@propagate_inbounds function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
) = column(x, colidx.ij..., colidx.h)

"""
Fields.bycolumn(fn, space)
Expand Down
5 changes: 5 additions & 0 deletions src/Operators/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,8 @@ end
strip_space_args(::Tuple{}, space) = ()
strip_space_args(args::Tuple, space) =
(strip_space(args[1], space), strip_space_args(Base.tail(args), space)...)

function unstrip_space(field::Field, parent_space)
new_space = reconstruct_placeholder_space(axes(field), parent_space)
return Field(Fields.field_values(field), new_space)
end
9 changes: 9 additions & 0 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3299,6 +3299,15 @@ function Base.Broadcast.materialize!(
)
end

Base.@propagate_inbounds column(op::FiniteDifferenceOperator, inds...) =
unionall_type(typeof(op))(column_args(op.bcs, inds...))
Base.@propagate_inbounds column(sbc::StencilBroadcasted{S}, inds...) where {S} =
StencilBroadcasted{S}(
column(sbc.op, inds...),
column_args(sbc.args, inds...),
column(sbc.axes, inds...),
)

#TODO: the optimizer dies with column broadcast expressions over a certain complexity
if hasfield(Method, :recursion_relation)
dont_limit = (args...) -> true
Expand Down
Loading
Loading