diff --git a/base/broadcast.jl b/base/broadcast.jl index b89826bb22bb39..cb23f6523ac2ff 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -314,9 +314,8 @@ some cases. """ function flatten(bc::Broadcasted{Style}) where {Style} isflat(bc) && return bc - # concatenate the nested arguments into {a, b, c, d} - # args = cat_nested(bc) - # build a function `makeargs` that takes a "flat" argument list and + # 1. concatenate the nested arguments into {a, b, c, d} + # 2. build a function `makeargs` that takes a "flat" argument list and # and creates the appropriate input arguments for `f`, e.g., # makeargs = (w, x, y, z) -> (w, g(x, y), z) # @@ -329,8 +328,7 @@ function flatten(bc::Broadcasted{Style}) where {Style} @inline function (args::Vararg{Any,N}) f(makeargs(args...)...) end - newf = _make(args) - return Broadcasted{Style}(newf, args, bc.axes) + return Broadcasted{Style}(_make(args), args, bc.axes) end end @@ -340,18 +338,13 @@ _isflat(args::NestedTuple) = false _isflat(args::Tuple) = _isflat(tail(args)) _isflat(args::Tuple{}) = true -cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...) -cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...) -cat_nested() = () - """ make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple Each element of `t` is one (consecutive) node in a broadcast tree. `args` contains the rest arguments on the "right" side of `t`. The jobs of `make_makeargs` are: - 1. append the flattened arguments in `t` at the beginning of `args`, i.e. - `(cat_nested(t)..., args...)` + 1. append the flattened arguments in `t` at the beginning of `args`. 2. return a function that takes in flattened argument list and returns a tuple (each entry corresponding to an entry in `t`, having evaluated the corresponding element in the broadcast tree). diff --git a/test/broadcast.jl b/test/broadcast.jl index 18534c6bab4e76..c17c90b21ded91 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -772,16 +772,19 @@ end # issue #27988: inference of Broadcast.flatten using .Broadcast: Broadcasted -let +let _cat_nested(bc) = Broadcast.flatten(bc).args bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5)))) - @test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5) + @test @inferred(_cat_nested(bc)) == (1,2,3,4,5) @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62 bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5)))) - @test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5) + @test @inferred(_cat_nested(bc)) == (1,2.0,2.5,3,4,5) @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8 # 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3 - bc = Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(-, (Base.Broadcast.Broadcasted(*, (1, 1)), Base.Broadcast.Broadcasted(*, (1, Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{2}}(Val{2}()))))))), Base.Broadcast.Broadcasted(*, (1, 1)))), Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}()))))) + bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}()))))) @test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2 + # @. 1 + 1 * (1 + 1 + 1 + 1) + bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1)))))) + @test @inferred(_cat_nested(bc)) == (1,1,1,1,1,1) # `cat_nested` failed to infer this end let