Skip to content

Commit

Permalink
Merge pull request #1615 from CliMA/ck/fv_specialize
Browse files Browse the repository at this point in the history
Specialize on diagonal fieldvector broadcasts
  • Loading branch information
charleskawczynski authored Feb 21, 2024
2 parents 202ed63 + 7d31f3b commit da32a43
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 0 deletions.
87 changes: 87 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,93 @@ end
@inline transform_broadcasted(fv::FieldVector, symb, axes) =
parent(getfield(_values(fv), symb))
@inline transform_broadcasted(x, symb, axes) = x

@inline function first_fieldvector_in_bc(args::Tuple, rargs...)
x1 = first_fieldvector_in_bc(args[1], rargs...)
x1 isa FieldVector && return x1
return first_fieldvector_in_bc(Base.tail(args), rargs...)
end

@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) =
first_fieldvector_in_bc(args[1], rargs...)
@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_fieldvector_in_bc(x) = nothing
@inline first_fieldvector_in_bc(x::FieldVector) = x

@inline first_fieldvector_in_bc(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) = first_fieldvector_in_bc(bc.args)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple,
rargs...,
) where {TStart} =
truesofar &&
_is_diagonal_bc(truesofar, TStart, args[1], rargs...) &&
_is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{Any},
rargs...,
) where {TStart} =
truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...)
@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{},
rargs...,
) where {TStart} = truesofar

@inline function _is_diagonal_bc(
truesofar,
::Type{TStart},
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) where {TStart}
return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args)
end

@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
::TStart,
) where {TStart <: FieldVector} = true
@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
x::FieldVector,
) where {TStart} = false
@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar

# Find the first fieldvector in the broadcast expression (BCE),
# and compare against every other fieldvector in the BCE
@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) =
_is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args)

# Specialize on FieldVectorStyle to avoid inference failure
# in fieldvector broadcast expressions:
# https://github.com/JuliaArrays/BlockArrays.jl/issues/310
function Base.Broadcast.instantiate(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
)
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = Base.Broadcast.combine_axes(bc.args...)
else
axes = bc.axes
# Base.Broadcast.check_broadcast_axes is type-unstable
# for broadcast expressions with multiple fieldvectors.
# So, let's statically elide this when we have "diagonal"
# broadcast expressions:
if !is_diagonal_bc(bc)
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
end
end
return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes)
end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
Expand Down
42 changes: 42 additions & 0 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,48 @@ end
@test Y.k.z === 3.0
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
@testset "Diagonal FieldVector broadcast expressions" begin
FT = Float64
device = ClimaComms.device()
comms_ctx = ClimaComms.context(device)
cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
fspace = TU.FaceExtrudedFiniteDifferenceSpace(FT; context = comms_ctx)
cx = Fields.fill((; a = FT(1), b = FT(2)), cspace)
cy = Fields.fill((; a = FT(1), b = FT(2)), cspace)
fx = Fields.fill((; a = FT(1), b = FT(2)), fspace)
fy = Fields.fill((; a = FT(1), b = FT(2)), fspace)
Y1 = Fields.FieldVector(; x = cx, y = cy)
Y2 = Fields.FieldVector(; x = cx, y = cy)
Y3 = Fields.FieldVector(; x = cx, y = cy)
Y4 = Fields.FieldVector(; x = cx, y = cy)
Z = Fields.FieldVector(; x = fx, y = fy)
function test_fv_allocations!(X1, X2, X3, X4)
@. X1 += X2 * X3 + X4
return nothing
end
test_fv_allocations!(Y1, Y2, Y3, Y4)
p_allocated = @allocated test_fv_allocations!(Y1, Y2, Y3, Y4)
if device isa ClimaComms.AbstractCPUDevice
@test p_allocated == 0
elseif device isa ClimaComms.CUDADevice
@test_broken p_allocated == 0
end

bc1 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y2)),
Base.broadcasted(:*, 3, Y3),
)
bc2 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y1)),
Base.broadcasted(:*, 3, Z),
)
@test Fields.is_diagonal_bc(bc1)
@test !Fields.is_diagonal_bc(bc2)
end

function call_getcolumn(fv, colidx)
@allowscalar fvcol = fv[colidx]
nothing
Expand Down

0 comments on commit da32a43

Please sign in to comment.