Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify Broadcast object computations #49395

Merged
merged 1 commit into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,28 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})

struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
style::Style
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
end

Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} =
Broadcasted{typeof(combine_styles(args...))}(f, args, axes)
function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple}
# using Core.Typeof rather than F preserves inferrability when f is a type
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
Broadcasted(style::Union{Nothing,BroadcastStyle}, f::Tuple, args::Tuple) = error() # disambiguation: tuple is not callable
function Broadcasted(style::Union{Nothing,BroadcastStyle}, f::F, args::Tuple, axes=nothing) where {F}
# using Core.Typeof rather than F preserves inferrability when f is a type
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(style, f, args, axes)
end

function Broadcasted(f::F, args::Tuple, axes=nothing) where {F}
Broadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
end

function Broadcasted{Style}(f::F, args, axes=nothing) where {Style, F}
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(Style()::Style, f, args, axes)
end

function Broadcasted{Style,Axes,F,Args}(f, args, axes) where {Style,Axes,F,Args}
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
end
end

struct AndAnd end
Expand All @@ -194,16 +206,16 @@ function broadcasted(::OrOr, a, bc::Broadcasted)
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
end

Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{<:Any,Axes,F,Args}) where {NewStyle,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)::Broadcasted{NewStyle,Axes,F,Args}

function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
print(io, Broadcasted)
# Only show the style parameter if we have a set of axes — representing an instantiated
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
# computation that is not relevant for dispatch, confusing, and just extra line noise.
bc.axes isa Tuple && print(io, '{', Style, '}')
print(io, '(', bc.f, ", ", bc.args, ')')
bc.axes isa Tuple && print(io, "{", Style, "}")
print(io, "(", bc.f, ", ", bc.args, ")")
nothing
end

Expand Down Expand Up @@ -231,7 +243,7 @@ BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))

argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args
argtype(::Type{BC}) where {BC<:Broadcasted} = fieldtype(BC, :args)
argtype(bc::Broadcasted) = argtype(typeof(bc))

@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
Expand Down Expand Up @@ -262,7 +274,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
end

Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N

_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
Expand All @@ -289,14 +301,14 @@ Custom [`BroadcastStyle`](@ref)s may override this default in cases where it is
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
of the `Broadcasted` object empty (populated with [`nothing`](@ref)).
"""
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
@inline function instantiate(bc::Broadcasted)
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = combine_axes(bc.args...)
else
axes = bc.axes
check_broadcast_axes(axes, bc.args...)
end
return Broadcasted{Style}(bc.f, bc.args, axes)
return Broadcasted(bc.style, bc.f, bc.args, axes)
end
instantiate(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc
# Tuples don't need axes, but when they have axes (for .= assignment), we need to check them (#33020)
Expand Down Expand Up @@ -325,7 +337,7 @@ becomes
This is an optional operation that may make custom implementation of broadcasting easier in
some cases.
"""
function flatten(bc::Broadcasted{Style}) where {Style}
function flatten(bc::Broadcasted)
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(bc)
Expand All @@ -341,7 +353,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
return Broadcasted(bc.style, newf, args, bc.axes)
end
end

Expand Down Expand Up @@ -895,11 +907,11 @@ materialize(x) = x
return materialize!(dest, instantiate(Broadcasted(identity, (x,), axes(dest))))
end

@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style}
@inline function materialize!(dest, bc::Broadcasted{<:Any})
return materialize!(combine_styles(dest, bc), dest, bc)
end
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{Style}) where {Style}
return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{<:Any})
return copyto!(dest, instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest))))
end

## general `copy` methods
Expand All @@ -909,7 +921,7 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) =

const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}

@inline function copy(bc::Broadcasted{Style}) where {Style}
@inline function copy(bc::Broadcasted)
ElType = combine_eltypes(bc.f, bc.args)
if Base.isconcretetype(ElType)
# We can trust it and defer to the simpler `copyto!`
Expand Down Expand Up @@ -968,7 +980,7 @@ broadcast_unalias(::Nothing, src) = src
# Preprocessing a `Broadcasted` does two things:
# * unaliases any arguments from `dest`
# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices
@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes)
@inline preprocess(dest, bc::Broadcasted) = Broadcasted(bc.style, bc.f, preprocess_args(dest, bc.args), bc.axes)
preprocess(dest, x) = extrude(broadcast_unalias(dest, x))

@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...)
Expand Down Expand Up @@ -1038,11 +1050,11 @@ ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischu
ischunkedbroadcast(R, args::Tuple{}) = true

# Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps
liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)∘(xor), map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,<:Any}) = Broadcasted(bc.style, bc.f, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(sign)}) = Broadcasted(bc.style, identity, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(!)}) = Broadcasted(bc.style, ~, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(*)}) = Broadcasted(bc.style, &, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(==)}) = Broadcasted(bc.style, (~)∘(xor), map(liftfuncs, bc.args), bc.axes)
liftfuncs(x) = x

liftchunks(::Tuple{}) = ()
Expand Down Expand Up @@ -1315,26 +1327,26 @@ end
return broadcasted((args...) -> f(args...; kwargs...), args...)
end
end
@inline function broadcasted(f, args...)
@inline function broadcasted(f::F, args...) where {F}
args′ = map(broadcastable, args)
broadcasted(combine_styles(args′...), f, args′...)
end
# Due to the current Type{T}/DataType specialization heuristics within Tuples,
# the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in
# mapping broadcastable across the args. These additional methods with explicit
# arguments ensure we preserve Type{T}s in the first or second argument position.
@inline function broadcasted(f, arg1, args...)
@inline function broadcasted(f::F, arg1, args...) where {F}
arg1′ = broadcastable(arg1)
args′ = map(broadcastable, args)
broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...)
end
@inline function broadcasted(f, arg1, arg2, args...)
@inline function broadcasted(f::F, arg1, arg2, args...) where {F}
arg1′ = broadcastable(arg1)
arg2′ = broadcastable(arg2)
args′ = map(broadcastable, args)
broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...)
end
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)
@inline broadcasted(style::BroadcastStyle, f::F, args...) where {F} = Broadcasted(style, f, args)

"""
BroadcastFunction{F} <: Function
Expand Down
2 changes: 1 addition & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ let
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()
@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()

# inference on nested
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
Expand Down