broadcast noop not inferred #56
Just leaving this here as a starting point for myself later:
julia> using Cassette, Base.Broadcast
julia> x = rand(1);
julia> Cassette.@context Ctx
julia> @code_typed Cassette.overdub(Ctx(), t -> Base.Broadcast.combine_axes(t...), (x, x))
CodeInfo(
25 1 ─ getfield(%%args, 1) │
│ getfield(%%args, 2) │
26 │ getfield(%%args, 1) │
│ getfield(%%args, 2) │
└── goto 3 if not false │
2 ─ nothing │
29 3 ┄ getfield(%%args, 1) │
│ %8 = getfield(%%args, 2)::Tuple{Array{Float64,1},Array{Float64,1}} │
│ %9 = :(Base.Broadcast)::Module │╻ #6
└── goto 4 if not false ││╻ overdub
4 ┄ %11 = π (%9, Module) │││╻ getproperty
│ %12 = π (:combine_axes, Symbol) ││││
│ %13 = :(Base.getfield)::typeof(getfield) ││││
└── goto 5 if not false ││││╻ overdub
5 ┄ Base.getfield(Cassette, :execute) │││││╻╷╷ recurse
│ %13(%11, %12) ││││││╻ macro expansion
└── goto 6 ││││╻ overdub
6 ─ goto 7 ││││
7 ─ goto 8 ││╻ overdub
8 ─ %20 = :(Core._apply)::typeof(Core._apply) │││╻ execute
│ %21 = %20(getfield(Main, Symbol("##4#5")){Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing},typeof(Base.Broadcast.combine_axes)}(Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}(nametype(Ctx)(), nothing, Cassette.NoPass(), nothing, nothing), Base.Broadcast.combine_axes), %8)::Any
└── goto 9 ││
31 9 ─ getfield(%%args, 1) │
│ getfield(%%args, 2) │
32 └── return %21 │
) => Any
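For comparison, here is a plain-Julia baseline. It is not from the original thread. It is a sketch that assumes Julia 1.x with only Base and the Test stdlib, where `Base.Broadcast.combine_axes` is still an internal, unexported function. Calling the same splatting closure directly, without `Cassette.overdub`, should infer a concrete axes tuple. That would point to the overdub path as the source of the `Any` above.

```julia
using Test

# Same closure as in the report: splat a tuple of arrays into combine_axes.
f = t -> Base.Broadcast.combine_axes(t...)

x = rand(1)

# Direct call, no Cassette. @inferred throws an error if the inferred
# return type differs from the type of the value actually returned.
ax = @inferred f((x, x))

# Inferred return type(s) for this argument type, returned as a Vector.
rt = Base.return_types(f, (Tuple{Vector{Float64},Vector{Float64}},))
println(ax, " ", rt)
```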
Okay, so the two problems here are
problematic
julia> Cassette.recurse_typed(Ctx(), t -> Base.Broadcast.combine_axes(t...), (rand(1), rand(1)))
1-element Array{Any,1}:
CodeInfo(
1 1 ─ :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1)) │
│ :(t = (Core.getfield)(Core.Compiler.Argument(3), 2)) │
│ %3 = :(Base.Broadcast)::Core.Compiler.Const(Base.Broadcast, false) │
│ %4 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %3, :combine_axes)::Core.Compiler.Const(Base.Broadcast.combine_axes, false)
│ %5 = Cassette.overdub(%%##recurse_context#376, Core._apply, %4, %%t)::Any
└── return %5 │
) => Any
problematic
julia> Cassette.recurse_typed(Ctx(), Broadcast.copy, Broadcast.instantiate(Broadcast.broadcasted(+, rand(1), rand(1))))
1-element Array{Any,1}:
CodeInfo(
1 ─ :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1)) │
│ :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2)) │
⋮
│ %14 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Union{typeof(+), Tuple}
│ %15 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Union{typeof(+), Tuple}
⋮
) => Any
Making both I'm okay with making The
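A similar baseline for the second case is shown below. It is also not from the thread and assumes Julia 1.x with the Test stdlib. Outside Cassette, literal field access on a `Broadcasted` object is expected to infer the concrete field type, because constant propagation sees the field name. `copy` should also infer `Vector{Float64}`. The helper names `getf` and `getargs` are hypothetical and exist only for this check.

```julia
using Test

bc = Broadcast.instantiate(Broadcast.broadcasted(+, rand(1), rand(1)))

# Hypothetical helpers: each accesses one field with a literal Symbol,
# mirroring the Base.getproperty(bc, :f) / (bc, :args) calls in the dump.
getf(b) = b.f
getargs(b) = b.args

fval = @inferred getf(bc)        # the broadcast function, +
argsval = @inferred getargs(bc)  # a tuple of two Vector{Float64}

# Full materialization of the broadcast, also checked for inferrability.
y = @inferred copy(bc)
```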
Argh. Fixing the However, having
julia> using Cassette, Base.Broadcast
julia> Cassette.@context Ctx
julia> b = Broadcast.instantiate(Broadcast.broadcasted(+, rand(1), rand(1)))
Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}}(+, ([0.0228267], [0.173769]))
julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1)
1-element Array{Any,1}:
CodeInfo(
1 ─ :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1)) │
│ :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2)) │
│ :(I = (Core.getfield)(Core.Compiler.Argument(3), 3)) │
│ nothing │
551 │ %5 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Tuple{Array{Float64,1},Array{Float64,1}}
│ :(args = (Cassette.overdub)(Core.Compiler.Argument(2), Base.Broadcast._getindex, Core.SSAValue(5), Core.Compiler.Argument(6)))
552 │ %7 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Core.Compiler.Const(+, false)
│ %8 = Cassette.overdub(%%##recurse_context#376, Core.tuple, %7)::Core.Compiler.Const((+,), false)
│ %9 = Cassette.overdub(%%##recurse_context#376, Core._apply, Base.Broadcast._broadcast_getindex_evalf, %8, %%args)::Any
└── return %9 │
) => Any
That's with
julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1)
1-element Array{Any,1}:
CodeInfo(
1 ─ :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1)) │
│ :(bc = (Core.getfield)(Core.Compiler.Argument(3), 2)) │
│ :(I = (Core.getfield)(Core.Compiler.Argument(3), 3)) │
│ nothing │
551 │ %5 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :args)::Tuple{Array{Float64,1},Array{Float64,1}}
│ :(args = (Cassette.overdub)(Core.Compiler.Argument(2), Base.Broadcast._getindex, Core.SSAValue(5), Core.Compiler.Argument(6)))
552 │ %7 = Cassette.overdub(%%##recurse_context#376, Base.getproperty, %%bc, :f)::Core.Compiler.Const(+, false)
│ %8 = Cassette.overdub(%%##recurse_context#376, Core.tuple, %7)::Core.Compiler.Const((+,), false)
│ %9 = Cassette.overdub(%%##recurse_context#376, Core._apply, Base.Broadcast._broadcast_getindex_evalf, %8, %%args)::Float64
└── return %9 │
) => Float64
If we can come up with an
Note that if I set
julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex, b, 1; optimize=true)
1-element Array{Any,1}:
CodeInfo(
⋮ # everything until here seems well-inferred
431 ┄ %746 = invoke Cassette.recurse(:($(QuoteNode(Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}(nametype(Ctx)(), nothing, Cassette.NoPass(), nothing, nothing))))::Cassette.Context{nametype(Ctx),Nothing,Cassette.NoPass,Nothing,Nothing}, Base.Broadcast._broadcast_getindex_evalf::typeof(Base.Broadcast._broadcast_getindex_evalf), %742::typeof(+), %743::Float64, %744::Float64)::Any
└──── goto 432
432 ─ goto 433
433 ─ goto 434
434 ─ goto 435 overdub
435 ─ return %746
) => Any
However, this call independently seems to infer fine:
julia> Cassette.recurse_typed(Ctx(), Broadcast._broadcast_getindex_evalf, +, 1.0, 1.0)
1-element Array{Any,1}:
CodeInfo(
1 ─ :(#self# = (Core.getfield)(Core.Compiler.Argument(3), 1)) │
│ :(f = (Core.getfield)(Core.Compiler.Argument(3), 2)) │
│ %3 = Core.getfield(%%##recurse_arguments#377, 3)::Float64 │
│ %4 = Core.getfield(%%##recurse_arguments#377, 4)::Float64 │
│ :(args = (Core.tuple)(Core.SSAValue(3), Core.SSAValue(4))) │
│ nothing │
579 │ %7 = Cassette.overdub(%%##recurse_context#376, Core._apply, %%f, %%args)::Float64
└── return %7 │
) => Float64
Hmm...
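One way to narrow this down further is to confirm that both levels infer concretely without Cassette. This is a sketch, not from the thread, and it assumes Julia 1.x, where `Base.Broadcast._broadcast_getindex_evalf` is an internal helper that may change between releases. The first check indexes an instantiated `Broadcasted`, which goes through `_broadcast_getindex`. The second calls the evalf helper on its own. If both pass, that would suggest the `Any` in the optimized output comes from how the recursive `Cassette.recurse` call on the inner function is inferred, rather than from Base itself.

```julia
using Test

b = Broadcast.instantiate(Broadcast.broadcasted(+, [1.0], [2.0]))

# Element access on a Broadcasted evaluates a single element lazily;
# internally this dispatches to Base.Broadcast._broadcast_getindex.
v = @inferred getindex(b, 1)

# The inner helper from the report, called directly (internal API).
w = @inferred Base.Broadcast._broadcast_getindex_evalf(+, 1.0, 1.0)
```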