diff --git a/src/Fields/fieldvector.jl b/src/Fields/fieldvector.jl index 93cd3d1be3..7ce1d51f8a 100644 --- a/src/Fields/fieldvector.jl +++ b/src/Fields/fieldvector.jl @@ -203,6 +203,93 @@ end @inline transform_broadcasted(fv::FieldVector, symb, axes) = parent(getfield(_values(fv), symb)) @inline transform_broadcasted(x, symb, axes) = x + +@inline function first_fieldvector_in_bc(args::Tuple, rargs...) + x1 = first_fieldvector_in_bc(args[1], rargs...) + x1 isa FieldVector && return x1 + return first_fieldvector_in_bc(Base.tail(args), rargs...) +end + +@inline first_fieldvector_in_bc(args::Tuple{Any}, rargs...) = + first_fieldvector_in_bc(args[1], rargs...) +@inline first_fieldvector_in_bc(args::Tuple{}, rargs...) = nothing +@inline first_fieldvector_in_bc(x) = nothing +@inline first_fieldvector_in_bc(x::FieldVector) = x + +@inline first_fieldvector_in_bc( + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, +) = first_fieldvector_in_bc(bc.args) + +@inline _is_diagonal_bc_args( + truesofar, + ::Type{TStart}, + args::Tuple, + rargs..., +) where {TStart} = + truesofar && + _is_diagonal_bc(truesofar, TStart, args[1], rargs...) && + _is_diagonal_bc_args(truesofar, TStart, Base.tail(args), rargs...) + +@inline _is_diagonal_bc_args( + truesofar, + ::Type{TStart}, + args::Tuple{Any}, + rargs..., +) where {TStart} = + truesofar && _is_diagonal_bc(truesofar, TStart, args[1], rargs...) +@inline _is_diagonal_bc_args( + truesofar, + ::Type{TStart}, + args::Tuple{}, + rargs..., +) where {TStart} = truesofar + +@inline function _is_diagonal_bc( + truesofar, + ::Type{TStart}, + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, +) where {TStart} + return truesofar && _is_diagonal_bc_args(truesofar, TStart, bc.args) +end + +@inline _is_diagonal_bc( + truesofar, + ::Type{TStart}, + ::TStart, +) where {TStart <: FieldVector} = true +@inline _is_diagonal_bc( + truesofar, + ::Type{TStart}, + x::FieldVector, +) where {TStart} = false +@inline _is_diagonal_bc(truesofar, ::Type{TStart}, x) where {TStart} = truesofar + +# Find the first fieldvector in the broadcast expression, +# and compare against every other broadca +@inline is_diagonal_bc(bc::Base.Broadcast.Broadcasted{FieldVectorStyle}) = + _is_diagonal_bc_args(true, typeof(first_fieldvector_in_bc(bc)), bc.args) + +# Specialize on FieldVectorStyle to avoid inference failure +# in fieldvector broadcast expressions: +# https://github.com/JuliaArrays/BlockArrays.jl/issues/310 +function Base.Broadcast.instantiate( + bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, +) + if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) + axes = Base.Broadcast.combine_axes(bc.args...) + else + axes = bc.axes + # Base.Broadcast.check_broadcast_axes is type-unstable + # for broadcast expressions with multiple fieldvectors. + # So, let's statically elide this when we have "diagonal" + # broadcast expressions: + if !is_diagonal_bc(bc) + Base.Broadcast.check_broadcast_axes(axes, bc.args...) + end + end + return Base.Broadcast.Broadcasted(bc.style, bc.f, bc.args, axes) +end + @inline function Base.copyto!( dest::FieldVector, bc::Base.Broadcast.Broadcasted{FieldVectorStyle}, diff --git a/test/Fields/field.jl b/test/Fields/field.jl index 9ace1c43c6..90011343da 100644 --- a/test/Fields/field.jl +++ b/test/Fields/field.jl @@ -278,6 +278,42 @@ end @test Y.k.z === 3.0 end +# https://github.com/CliMA/ClimaCore.jl/issues/1465 +@testset "Diagonal FieldVector broadcast expressions" begin + FT = Float64 + cspace = TU.CenterExtrudedFiniteDifferenceSpace(FT) + fspace = TU.FaceExtrudedFiniteDifferenceSpace(FT) + cx = Fields.coordinate_field(cspace) + cy = Fields.coordinate_field(cspace) + fx = Fields.coordinate_field(fspace) + fy = Fields.coordinate_field(fspace) + Y1 = Fields.FieldVector(; x = cx, y = cy) + Y2 = Fields.FieldVector(; x = cx, y = cy) + Y3 = Fields.FieldVector(; x = cx, y = cy) + Y4 = Fields.FieldVector(; x = cx, y = cy) + Z = Fields.FieldVector(; x = fx, y = fy) + function test_fv_allocations!(X1, X2, X3, X4) + @. X1 += X2 * X3 + X4 + return nothing + end + test_fv_allocations!(Y1, Y2, Y3, Y4) + p_allocated = @allocated test_fv_allocations!(Y1, Y2, Y3, Y4) + @test p_allocated == 0 + + bc1 = Base.broadcasted( + :-, + Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y2)), + Base.broadcasted(:*, 3, Y3), + ) + bc2 = Base.broadcasted( + :-, + Base.broadcasted(:+, Y1, Base.broadcasted(:*, 2, Y1)), + Base.broadcasted(:*, 3, Z), + ) + @test Fields.is_diagonal_bc(bc1) + @test !Fields.is_diagonal_bc(bc2) +end + function call_getcolumn(fv, colidx) @allowscalar fvcol = fv[colidx] nothing