Skip to content

Commit

Permalink
Specialize on diagonal fieldvector broadcasts
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Feb 20, 2024
1 parent 202ed63 commit 94f0c63
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
87 changes: 87 additions & 0 deletions src/Fields/fieldvector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,93 @@ end
@inline transform_broadcasted(fv::FieldVector, symb, axes) =
parent(getfield(_values(fv), symb))
@inline transform_broadcasted(x, symb, axes) = x

@inline function first_fieldvector_in_bc(args::Tuple, rargs...)
x1 = first_fieldvector_in_bc(args[1], rargs...)
x1 isa FieldVector && return x1
return first_fieldvector_in_bc(Base.tail(args), rargs...)
end

@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) =
first_fieldvector_in_bc(args[1], rargs...)
@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing
@inline first_fieldvector_in_bc(x) = nothing
@inline first_fieldvector_in_bc(x::FieldVector) = x

@inline first_fieldvector_in_bc(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) = first_fieldvector_in_bc(bc.args)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple,
rargs...,
) where {TStart} =
truesofar &&
_is_diagonal_bc(truesofar, TStart, args[1], rargs...) &&
_is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...)

@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{Any},
rargs...,
) where {TStart} =
truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...)
@inline _is_diagonal_bc_args(
truesofar,
::Type{TStart},
args::Tuple{},
rargs...,
) where {TStart} = truesofar

@inline function _is_diagonal_bc(
truesofar,
::Type{TStart},
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
) where {TStart}
return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args)
end

@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
::TStart,
) where {TStart <: FieldVector} = true
@inline _is_diagonal_bc(
truesofar,
::Type{TStart},
x::FieldVector,
) where {TStart} = false
@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar

# Find the first fieldvector in the broadcast expression,
# and compare against every other broadca
@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) =
_is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args)

# Specialize on FieldVectorStyle to avoid inference failure
# in fieldvector broadcast expressions:
# https://github.com/JuliaArrays/BlockArrays.jl/issues/310
function Base.Broadcast.instantiate(
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
)
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = Base.Broadcast.combine_axes(bc.args...)
else
axes = bc.axes
# Base.Broadcast.check_broadcast_axes is type-unstable
# for broadcast expressions with multiple fieldvectors.
# So, let's statically elide this when we have "diagonal"
# broadcast expressions:
if !is_diagonal_bc(bc)
Base.Broadcast.check_broadcast_axes(axes, bc.args...)
end
end
return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes)
end

@inline function Base.copyto!(
dest::FieldVector,
bc::Base.Broadcast.Broadcasted{FieldVectorStyle},
Expand Down
36 changes: 36 additions & 0 deletions test/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,42 @@ end
@test Y.k.z === 3.0
end

# https://github.com/CliMA/ClimaCore.jl/issues/1465
@testset "Diagonal FieldVector broadcast expressions" begin
FT = Float64
cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT)
fspace = TU.FaceExtrudedFiniteDifferenceSpace(FT)
cx = Fields.coordinate_field(cspace)
cy = Fields.coordinate_field(cspace)
fx = Fields.coordinate_field(fspace)
fy = Fields.coordinate_field(fspace)
Y1 = Fields.FieldVector(; x = cx, y = cy)
Y2 = Fields.FieldVector(; x = cx, y = cy)
Y3 = Fields.FieldVector(; x = cx, y = cy)
Y4 = Fields.FieldVector(; x = cx, y = cy)
Z = Fields.FieldVector(; x = fx, y = fy)
function test_fv_allocations!(X1, X2, X3, X4)
@. X1 += X2 * X3 + X4
return nothing
end
test_fv_allocations!(Y1, Y2, Y3, Y4)
p_allocated = @allocated test_fv_allocations!(Y1, Y2, Y3, Y4)
@test p_allocated == 0

bc1 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y2)),
Base.broadcasted(:*, 3, Y3),
)
bc2 = Base.broadcasted(
:-,
Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y1)),
Base.broadcasted(:*, 3, Z),
)
@test Fields.is_diagonal_bc(bc1)
@test !Fields.is_diagonal_bc(bc2)
end

function call_getcolumn(fv, colidx)
@allowscalar fvcol = fv[colidx]
nothing
Expand Down

0 comments on commit 94f0c63

Please sign in to comment.