From 5d0a527bbb9948f2ea0dfd5770dfd45cfd4aa018 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Mon, 17 Apr 2023 16:24:55 -0400 Subject: [PATCH] simplify Broadcast object computations Code should normally preserve values, not the types of values. This ensures the user can define styles with metadata, and requires less type-parameter-based programming, but rather can focus on the values. --- base/broadcast.jl | 70 +++++++++++++++++++++++++++-------------------- test/broadcast.jl | 2 +- 2 files changed, 42 insertions(+), 30 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index d86b5cd92e02f..955a5652353d7 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -167,16 +167,28 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} = # copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle}) struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted + style::Style f::F args::Args axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`) -end -Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} = - Broadcasted{typeof(combine_styles(args...))}(f, args, axes) -function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple} - # using Core.Typeof rather than F preserves inferrability when f is a type - Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes) + Broadcasted(style::Union{Nothing,BroadcastStyle}, f::Tuple, args::Tuple) = error() # disambiguation: tuple is not callable + function Broadcasted(style::Union{Nothing,BroadcastStyle}, f::F, args::Tuple, axes=nothing) where {F} + # using Core.Typeof rather than F preserves inferrability when f is a type + return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(style, f, args, axes) + end + + function Broadcasted(f::F, args::Tuple, axes=nothing) where {F} + Broadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes) + end + + function Broadcasted{Style}(f::F, args, axes=nothing) where {Style, F} + return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(Style()::Style, f, args, axes) + end + + function Broadcasted{Style,Axes,F,Args}(f, args, axes) where {Style,Axes,F,Args} + return new{Style, Axes, F, Args}(Style()::Style, f, args, axes) + end end struct AndAnd end @@ -194,7 +206,7 @@ function broadcasted(::OrOr, a, bc::Broadcasted) broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...) end -Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} = +Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{<:Any,Axes,F,Args}) where {NewStyle,Axes,F,Args} = Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)::Broadcasted{NewStyle,Axes,F,Args} function Base.show(io::IO, bc::Broadcasted{Style}) where {Style} @@ -202,8 +214,8 @@ function Base.show(io::IO, bc::Broadcasted{Style}) where {Style} # Only show the style parameter if we have a set of axes — representing an instantiated # "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate # computation that is not relevant for dispatch, confusing, and just extra line noise. - bc.axes isa Tuple && print(io, '{', Style, '}') - print(io, '(', bc.f, ", ", bc.args, ')') + bc.axes isa Tuple && print(io, "{", Style, "}") + print(io, "(", bc.f, ", ", bc.args, ")") nothing end @@ -231,7 +243,7 @@ BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style() BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} = throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned")) -argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args +argtype(::Type{BC}) where {BC<:Broadcasted} = fieldtype(BC, :args) argtype(bc::Broadcasted) = argtype(typeof(bc)) @inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc)) @@ -262,7 +274,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s) end Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() -Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2)) +Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args)) Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N _maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T)))) @@ -289,14 +301,14 @@ Custom [`BroadcastStyle`](@ref)s may override this default in cases where it is to compute and verify the resulting `axes` on-demand, leaving the `axis` field of the `Broadcasted` object empty (populated with [`nothing`](@ref)). """ -@inline function instantiate(bc::Broadcasted{Style}) where {Style} +@inline function instantiate(bc::Broadcasted) if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) axes = combine_axes(bc.args...) else axes = bc.axes check_broadcast_axes(axes, bc.args...) end - return Broadcasted{Style}(bc.f, bc.args, axes) + return Broadcasted(bc.style, bc.f, bc.args, axes) end instantiate(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc # Tuples don't need axes, but when they have axes (for .= assignment), we need to check them (#33020) @@ -325,7 +337,7 @@ becomes This is an optional operation that may make custom implementation of broadcasting easier in some cases. """ -function flatten(bc::Broadcasted{Style}) where {Style} +function flatten(bc::Broadcasted) isflat(bc) && return bc # concatenate the nested arguments into {a, b, c, d} args = cat_nested(bc) @@ -341,7 +353,7 @@ function flatten(bc::Broadcasted{Style}) where {Style} newf = @inline function(args::Vararg{Any,N}) where N f(makeargs(args...)...) end - return Broadcasted{Style}(newf, args, bc.axes) + return Broadcasted(bc.style, newf, args, bc.axes) end end @@ -895,11 +907,11 @@ materialize(x) = x return materialize!(dest, instantiate(Broadcasted(identity, (x,), axes(dest)))) end -@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style} +@inline function materialize!(dest, bc::Broadcasted{<:Any}) return materialize!(combine_styles(dest, bc), dest, bc) end -@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{Style}) where {Style} - return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest)))) +@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{<:Any}) + return copyto!(dest, instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest)))) end ## general `copy` methods @@ -909,7 +921,7 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) = const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict} -@inline function copy(bc::Broadcasted{Style}) where {Style} +@inline function copy(bc::Broadcasted) ElType = combine_eltypes(bc.f, bc.args) if Base.isconcretetype(ElType) # We can trust it and defer to the simpler `copyto!` @@ -968,7 +980,7 @@ broadcast_unalias(::Nothing, src) = src # Preprocessing a `Broadcasted` does two things: # * unaliases any arguments from `dest` # * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices -@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes) +@inline preprocess(dest, bc::Broadcasted) = Broadcasted(bc.style, bc.f, preprocess_args(dest, bc.args), bc.axes) preprocess(dest, x) = extrude(broadcast_unalias(dest, x)) @inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...) @@ -1038,11 +1050,11 @@ ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischu ischunkedbroadcast(R, args::Tuple{}) = true # Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps -liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes) -liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes) -liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes) -liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes) -liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)∘(xor), map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{<:Any,<:Any,<:Any}) = Broadcasted(bc.style, bc.f, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(sign)}) = Broadcasted(bc.style, identity, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(!)}) = Broadcasted(bc.style, ~, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(*)}) = Broadcasted(bc.style, &, map(liftfuncs, bc.args), bc.axes) +liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(==)}) = Broadcasted(bc.style, (~)∘(xor), map(liftfuncs, bc.args), bc.axes) liftfuncs(x) = x liftchunks(::Tuple{}) = () @@ -1315,7 +1327,7 @@ end return broadcasted((args...) -> f(args...; kwargs...), args...) end end -@inline function broadcasted(f, args...) +@inline function broadcasted(f::F, args...) where {F} args′ = map(broadcastable, args) broadcasted(combine_styles(args′...), f, args′...) end @@ -1323,18 +1335,18 @@ end # the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in # mapping broadcastable across the args. These additional methods with explicit # arguments ensure we preserve Type{T}s in the first or second argument position. -@inline function broadcasted(f, arg1, args...) +@inline function broadcasted(f::F, arg1, args...) where {F} arg1′ = broadcastable(arg1) args′ = map(broadcastable, args) broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...) end -@inline function broadcasted(f, arg1, arg2, args...) +@inline function broadcasted(f::F, arg1, arg2, args...) where {F} arg1′ = broadcastable(arg1) arg2′ = broadcastable(arg2) args′ = map(broadcastable, args) broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...) end -@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args) +@inline broadcasted(style::BroadcastStyle, f::F, args...) where {F} = Broadcasted(style, f, args) """ BroadcastFunction{F} <: Function diff --git a/test/broadcast.jl b/test/broadcast.jl index 41ca604cb50e4..87858dd0f08fc 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -880,7 +880,7 @@ let @test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict} @test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}} - @test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}() + @test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}() # inference on nested bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))