Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inline all FD operators and axisvector conversions #1462

Merged
merged 3 commits into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ steps:
- label: "Unit: matrix field broadcasting (GPU)"
key: unit_matrix_field_broadcasting_gpu
command: "julia --color=yes --project=test test/MatrixFields/matrix_field_broadcasting.jl"
soft_fail: true
agents:
slurm_gpus: 1
slurm_mem: 40GB
Expand All @@ -551,6 +552,7 @@ steps:
- label: "Unit: operator matrices (GPU)"
key: unit_operator_matrices_gpu
command: "julia --color=yes --project=test test/MatrixFields/operator_matrices.jl"
soft_fail: true
agents:
slurm_gpus: 1
slurm_mem: 40GB
Expand Down
11 changes: 7 additions & 4 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,10 +609,13 @@ end
push!(vals, val)
end
end
return :(@inbounds Axis2Tensor(
(ato, axes(x, 2)),
SMatrix{$(length(Ito)), $M}($(vals...)),
))
quote
Base.@_propagate_inbounds_meta
@inbounds Axis2Tensor(
(ato, axes(x, 2)),
SMatrix{$(length(Ito)), $M}($(vals...)),
)
end
end

@inline transform(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
Expand Down
40 changes: 24 additions & 16 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,65 +49,73 @@ CovariantVector(
) where {T, I} = local_geometry.gᵢⱼ * u

# Converting to specific dimension types
(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::ContravariantVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::ContravariantVector,
::LocalGeometry,
) where {I} = project(ContravariantAxis{I}(), u)

(::Type{<:ContravariantVector{<:Any, I}})(
@inline (::Type{<:ContravariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} =
project(ContravariantAxis{I}(), ContravariantVector(u, local_geometry))

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::CovariantVector,
::LocalGeometry,
) where {I} = project(CovariantAxis{I}(), u)

(::Type{<:CovariantVector{<:Any, I}})(
@inline (::Type{<:CovariantVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = project(CovariantAxis{I}(), CovariantVector(u, local_geometry))

(::Type{<:LocalVector{<:Any, I}})(
@inline (::Type{<:LocalVector{<:Any, I}})(
u::LocalVector{<:Any, I},
::LocalGeometry{I},
) where {I} = u

(::Type{<:LocalVector{<:Any, I}})(u::LocalVector, ::LocalGeometry) where {I} =
project(LocalAxis{I}(), u)
@inline (::Type{<:LocalVector{<:Any, I}})(
u::LocalVector,
::LocalGeometry,
) where {I} = project(LocalAxis{I}(), u)

(::Type{<:LocalVector{<:Any, I}})(
@inline (::Type{<:LocalVector{<:Any, I}})(
u::AxisVector,
local_geometry::LocalGeometry,
) where {I} = project(LocalAxis{I}(), LocalVector(u, local_geometry))

# Generic N-axis conversion functions,
# Convert to specific local geometry dimension then convert vector type
LocalVector(u::CovariantVector, local_geometry::LocalGeometry{I}) where {I} =
@inline LocalVector(
u::CovariantVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(LocalAxis{I}(), project(CovariantAxis{I}(), u), local_geometry)

LocalVector(
@inline LocalVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(LocalAxis{I}(), project(ContravariantAxis{I}(), u), local_geometry)

CovariantVector(u::LocalVector, local_geometry::LocalGeometry{I}) where {I} =
@inline CovariantVector(
u::LocalVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(CovariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

CovariantVector(
@inline CovariantVector(
u::ContravariantVector,
local_geometry::LocalGeometry{I},
) where {I} = project(
Expand All @@ -116,13 +124,13 @@ CovariantVector(
local_geometry,
)

ContravariantVector(
@inline ContravariantVector(
u::LocalVector,
local_geometry::LocalGeometry{I},
) where {I} =
project(ContravariantAxis{I}(), project(LocalAxis{I}(), u), local_geometry)

ContravariantVector(
@inline ContravariantVector(
u::CovariantVector,
local_geometry::LocalGeometry{I},
) where {I} = project(
Expand Down
2 changes: 1 addition & 1 deletion src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3446,7 +3446,7 @@ function Base.copyto!(
max_threads = 256
nitems = Nv * Nq * Nq * Nh # # of independent items
(nthreads, nblocks) = Spaces._configure_threadblock(max_threads, nitems)
@cuda threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!(
@cuda always_inline = true threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!(
strip_space(out, space),
strip_space(bc, space),
axes(out),
Expand Down
52 changes: 52 additions & 0 deletions test/MatrixFields/matrix_field_broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ include("matrix_field_test_utils.jl")
mul!(_result, _ᶜᶜmat, _ᶜvec),
)

GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30
test_field_broadcast_against_array_reference(;
test_name = "tri-diagonal matrix times vector",
get_result = () -> (@. ᶠᶠmat ⋅ ᶠvec),
Expand All @@ -32,6 +34,8 @@ include("matrix_field_test_utils.jl")
ref_set_result! = (_result, _ᶠᶠmat, _ᶠvec) ->
mul!(_result, _ᶠᶠmat, _ᶠvec),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "quad-diagonal matrix times vector",
Expand All @@ -41,6 +45,8 @@ include("matrix_field_test_utils.jl")
ref_set_result! = (_result, _ᶠᶜmat, _ᶜvec) ->
mul!(_result, _ᶠᶜmat, _ᶜvec),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "diagonal matrix times bi-diagonal matrix",
Expand All @@ -50,6 +56,8 @@ include("matrix_field_test_utils.jl")
ref_set_result! = (_result, _ᶜᶜmat, _ᶜᶠmat) ->
mul!(_result, _ᶜᶜmat, _ᶜᶠmat),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "tri-diagonal matrix times tri-diagonal matrix",
Expand All @@ -58,6 +66,8 @@ include("matrix_field_test_utils.jl")
input_fields = (ᶠᶠmat,),
ref_set_result! = (_result, _ᶠᶠmat) -> mul!(_result, _ᶠᶠmat, _ᶠᶠmat),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "quad-diagonal matrix times diagonal matrix",
Expand All @@ -67,6 +77,8 @@ include("matrix_field_test_utils.jl")
ref_set_result! = (_result, _ᶠᶜmat, _ᶜᶜmat) ->
mul!(_result, _ᶠᶜmat, _ᶜᶜmat),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "diagonal matrix times bi-diagonal matrix times \
Expand All @@ -90,6 +102,8 @@ include("matrix_field_test_utils.jl")
mul!(_result, _temp2, _ᶠᶜmat)
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "diagonal matrix times bi-diagonal matrix times \
Expand All @@ -115,6 +129,8 @@ include("matrix_field_test_utils.jl")
end,
test_broken_with_cuda = true, # TODO: Fix this.
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "diagonal matrix times bi-diagonal matrix times \
Expand Down Expand Up @@ -146,6 +162,8 @@ include("matrix_field_test_utils.jl")
mul!(_result, _temp3, _ᶜvec)
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "diagonal matrix times bi-diagonal matrix times \
Expand Down Expand Up @@ -179,6 +197,8 @@ include("matrix_field_test_utils.jl")
time_ratio_limit = 15, # This case's ref function is fast on Buildkite.
test_broken_with_cuda = true, # TODO: Fix this.
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "linear combination of matrix products and LinearAlgebra.I",
Expand Down Expand Up @@ -212,6 +232,8 @@ include("matrix_field_test_utils.jl")
@. _result = _temp3 + _temp4 / 3 - _result
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "another linear combination of matrix products and \
Expand Down Expand Up @@ -246,6 +268,8 @@ include("matrix_field_test_utils.jl")
@. _result = _temp2 * 2 - _temp4 + _result
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "matrix times linear combination",
Expand Down Expand Up @@ -282,6 +306,8 @@ include("matrix_field_test_utils.jl")
mul!(_result, _ᶜᶠmat, _temp5)
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "linear combination times another linear combination",
Expand Down Expand Up @@ -337,6 +363,8 @@ include("matrix_field_test_utils.jl")
end,
max_eps_error_limit = 30, # This case's roundoff error is large on GPUs.
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "matrix times matrix times linear combination times matrix \
Expand Down Expand Up @@ -416,6 +444,8 @@ include("matrix_field_test_utils.jl")
end,
max_eps_error_limit = 70, # This case's roundoff error is large on GPUs.
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast_against_array_reference(;
test_name = "matrix constructions and multiplications",
Expand Down Expand Up @@ -465,8 +495,13 @@ include("matrix_field_test_utils.jl")
mul!(_result, _temp4, _temp6)
end,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30
end

GC.gc();
@info "mem usage" rss = Sys.maxrss() / 2^30;

@testset "Non-scalar Matrix Field Broadcasting" begin
FT = Float64
center_space, face_space = test_spaces(FT)
Expand Down Expand Up @@ -496,6 +531,8 @@ end
ᶠᶜmat2,
ᶠᶜmat3,
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast(;
test_name = "matrix of covectors times matrix of vectors",
Expand All @@ -507,6 +544,8 @@ end
DiagonalMatrixRow(ᶠlg.gⁱʲ.components.data.:2) ⋅ ᶠᶜmat3
)),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

test_field_broadcast(;
test_name = "matrix of covectors times matrix of vectors times matrix \
Expand All @@ -525,6 +564,8 @@ end
DiagonalMatrixRow(ᶜlg.gⁱʲ.components.data.:2) ⋅ ᶜᶠmat3
)),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

ᶜᶠmat_AC1_num =
map((row1, row2) -> map(tuple, row1, row2), ᶜᶠmat_AC1, ᶜᶠmat)
Expand All @@ -533,6 +574,8 @@ end
ᶠᶜmat_C12_AC1 =
map((row1, row2) -> map(tuple, row1, row2), ᶠᶜmat_C12, ᶠᶜmat_AC1)

GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30
test_field_broadcast(;
test_name = "matrix of covectors and numbers times matrix of vectors \
and covectors times matrix of numbers and vectors times \
Expand All @@ -552,13 +595,17 @@ end
) ⋅ ᶠvec,
)),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30

ᶜvec_NT = @. nested_type(ᶜvec, ᶜvec, ᶜvec)
ᶜᶠmat_NT =
map((rows...) -> map(nested_type, rows...), ᶜᶠmat, ᶜᶠmat2, ᶜᶠmat3)
ᶠᶜmat_NT =
map((rows...) -> map(nested_type, rows...), ᶠᶜmat, ᶠᶜmat2, ᶠᶜmat3)

GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30
test_field_broadcast(;
test_name = "matrix of nested values times matrix of nested values \
times matrix of numbers times matrix of numbers times \
Expand All @@ -572,4 +619,9 @@ end
ᶜᶠmat3 ⋅ ᶠᶜmat ⋅ ᶜᶠmat ⋅ ᶠᶜmat3 ⋅ ᶜvec,
)),
)
GC.gc()
@info "mem usage" rss = Sys.maxrss() / 2^30
end

GC.gc();
@info "mem usage" rss = Sys.maxrss() / 2^30;