Skip to content

Commit

Permalink
Very WIP: RFC: Add an alternative implementation of closures
Browse files Browse the repository at this point in the history
 # Overview

This PR is a sketch implementation of an alternative mechanism of
implementing closures. It is designed to be complimenatry to the
existing closure mechanism and makes some different trade offs.
The motivation for this mechanism comes primarily from closure-based
AD tools like Zygote, but I'm expecting it will find other use cases
as well. It's a little hard to name this, because it's just another
way to implement closures. I discussed this with jeff and options
that were considered were "Arrow Closures" or "Non-Nominal Closures",
but for now I'm just calling them "Yet Another Kind of Closure" (YAKC,
pronounced yak-c, rhymes with yahtzee). This PR is in very early
stages, so in order to explain what this is and what this does, I
may describe features that are not yet implemented. See the end of the
commit message to see the current status.

 # Motivation
 ## Optimization across closure boundaries

Consider the following situation (type annotations are inference
results, not type asserts)
```
function foo()
    a = expensive_but_effect_free()::Any
    b = something()::Float64
    ()->isa(b, Float64) ? return a : return nothing
end
```

now, the traditional closure mechanism will lower this to:

```
struct ###{T, S}
   a::T
   b::S
end
(x::###{T,S}) = isa(b, Float64) ? return a : return nothing
function foo()
    a = expensive_but_effect_free()::Any
    b = something()::Float64
    new(a, b)
end
```

the problem with this is apparent: Even though after inference,
we know that `a` is unused in the closure (and thus would be
able to delete the expensive call were it not for the capture),
we may not delete it, simply because we need to satisfy the full
capture list of the closure. Ideally, we would like to have a mechanism
where the optimizer may modify the capture list of a closure in
response to information it discovers.

 ## Closures from Casette transforms

Compiler passes like Zygote would like to generate new closures
from untyped IR (i.e. after the frontend runs) (and in the future
potentially typed IR also). We currently do not have a great mechanism
to support this (it is somewhat possible by constructing an object
that redoes the primal analysis, but it's awkward at the very least).
This provides a very straightforward implementation of this feature.

 # Mechanism

The primary concept introduced by this PR is the `YAKC` type, defined
as follows:
```
struct YAKC{A <: Tuple, R}
     env::Any
     ci::CodeInfo
 end

 function (y::YAKC{A, R})(args...) where {A,R}
     typeassert(args, A)
     ccall(:jl_invoke_yakc, Any, (Any, Any), y, args)::R
 end
```
The dynamic semantics are that calling the yakc will run whatever code
is stored in the `.ci` object, using `env` as the self argument. This
is augmented by special support in inference and the optimizer to
co-optimize yakcs that appear in bodies of functions along with their
containing functions in order to enable things like cross-closure DCE.

Note that argument types and return types are explicitly specified,
rather than inferred. The reason for this is to prevent YAKCs from
participating in inference cycles and to allow return-type information
to be provided to inference without having to look at the contained
CodeInfo (because the contents of the CodeInfo are not interprocedurally
valid and may in general depend on what the optimizer was able to
figure out about the program). This is also done with a few to an
extension where the CodeInfo inside the yakc is not generated until
later in the optimization pipeline. It would be possible to optionally
allow return-type inference of the yatc based on the unoptimized
CodeInfo (if available), but that is not currently within the scope
of my planned work in this PR.

 # Status

The PR has the bare bones support for yakcs, including some initial
support for inling and inference, though that support is known to
be incorrect and incomplete. There are also currently no nice
front-end forms. For explanatory and testing purposes, I would like to
provide a macro of the form:
```
function foo()
    a = expensive_but_effect_free()::Any
    b = something()::Float64
    @yakc ()->isa(b, Float64) ? return a : return nothing
end
```
but that is not implemented yet. At the moment, the codeinfo needs
to be spliced in manually (here we're abusing `code_lowered` to construct
us a CodeInfo of the appropriate form), e.g.:
```
julia> bar() = 1
bar (generic function with 1 method)

julia> ci = @code_lowered bar();

julia> @eval function foo1()
          f = $(Expr(:new, :(Core.YAKC{Tuple{}, Int64}), nothing, ci))
       end
foo1 (generic function with 1 method)

julia> @eval function foo2()
          f = $(Expr(:new, :(Core.YAKC{Tuple{}, Int64}), nothing, ci))
          f()
       end
foo2 (generic function with 1 method)

julia> foo1()
Core.YAKC{Tuple{},Int64}(nothing, CodeInfo(
1 ─     return 1
))

julia> foo1()()
1

julia> @code_typed foo2()
CodeInfo(
1 ─     return 1
) => Int64

julia> struct Test
           a::Int
           b::Int
       end

julia> (a::Test)() = getfield(a, 1) + getfield(a, 2)

julia> ci2 = @code_lowered Test(1, 2)();

julia> @eval function foo2()
          f = $(Expr(:new, :(Core.YAKC{Tuple{}, Int64}), (1, 2), ci2))
          f()
       end
foo2 (generic function with 1 method)

julia> @code_typed foo2()
CodeInfo(
1 ─     return 3
) => Int64
```

TODO:
 - [ ] Show that this actually helps the Zygote use case I care about
 - [ ] Frontend support
 - [ ] Better optimizations (detection of need for recursive inlining)
 - [ ] Codegen for yakcs (right now they're interpreted)
  • Loading branch information
Keno committed Sep 21, 2020
1 parent fafc0a4 commit 176534b
Show file tree
Hide file tree
Showing 32 changed files with 626 additions and 60 deletions.
3 changes: 3 additions & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ using .Libc: getpid, gethostname, time

include("env.jl")

# YAKC
include("yakc.jl")

# Concurrency
include("linked_list.jl")
include("condition.jl")
Expand Down
10 changes: 10 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,4 +785,14 @@ Integer(x::Union{Float32, Float64}) = Int(x)
# The internal jl_parse which will call into Core._parse if not `nothing`.
_parse = nothing

# YAKC Definition
#=
struct YAKC{A <: Tuple, R}
env::Any
ci::CodeInfo
fptr1
fptr
end
=#

ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)
53 changes: 49 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
edges = Any[]
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
istoplevel = sv.linfo !== nothing && sv.linfo.def isa Module
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
Expand Down Expand Up @@ -337,7 +337,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing}
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
if infstate.linfo !== nothing && method === infstate.linfo.def
if infstate.linfo.specTypes == sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
Expand Down Expand Up @@ -773,6 +773,24 @@ function argtype_tail(argtypes::Vector{Any}, i::Int)
return argtypes[i:n]
end

function _yakc_tfunc(@nospecialize(arg), @nospecialize(lb), @nospecialize(ub),
@nospecialize(env), @nospecialize(ci), linfo::MethodInstance)
argt, argt_exact = instanceof_tfunc(arg)
lbt, lb_exact = instanceof_tfunc(lb)
if !lb_exact
lbt = Union{}
end

ubt, ub_exact = instanceof_tfunc(ub)

t = Core.YAKC{argt_exact ? argt : <:argt}
t = t{(lbt == ubt && ub_exact) ? ubt : T} where lbt<:T<:ubt

isa(ci, Const) || return t

PartialYAKC(t, env, linfo, ci.val)
end

function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::Union{Nothing,Vector{Any}},
argtypes::Vector{Any}, sv::InferenceState, max_methods::Int)
la = length(argtypes)
Expand All @@ -790,6 +808,10 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
ty = typeintersect(ty, cnd.elsetype)
end
return tmerge(tx, ty)
elseif f === Core._yakc
la == 6 || return Union{}
return _yakc_tfunc(argtypes[2], argtypes[3], argtypes[4],
argtypes[5], argtypes[6], sv.linfo)
end
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] Tuple
Expand Down Expand Up @@ -1003,6 +1025,27 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_call_gf_by_type(interp, f, argtypes, atype, sv, max_methods)
end

function abstract_call_yakc(interp::AbstractInterpreter, yakc::PartialYAKC, argtypes::Vector{Any}, sv::InferenceState)
if isa(yakc.ci, CodeInfo)
nargtypes = argtypes[2:end]
pushfirst!(nargtypes, yakc.env)
result = InferenceResult(Core.YAKC, nargtypes)
state = InferenceState(result, copy(yakc.ci), false, interp)
typeinf_local(interp, state)
finish(state, interp)
yakc.ci = result.src
return CallMeta(result.result, false)
elseif isa(yakc.ci, OptimizationState)
return CallMeta(yakc.ci.src.rettype, nothing)
else
nargtypes = argtypes[2:end]
pushfirst!(nargtypes, Core.YAKC)
sig = argtypes_to_type(nargtypes)
rt, edge = abstract_call_method(interp, yakc.ci::Method, sig, Core.svec(), false, sv)
return CallMeta(rt, edge)
end
end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
Expand All @@ -1014,6 +1057,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
f = ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialYAKC)
return abstract_call_yakc(interp, ft, argtypes, sv)
else
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
Expand Down Expand Up @@ -1326,14 +1371,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialYAKC)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
end
if tchanged(rt, frame.bestguess)
# new (wider) return type for frame
frame.bestguess = tmerge(frame.bestguess, rt)
frame.bestguess = frame.bestguess === NOT_FOUND ? rt : tmerge(frame.bestguess, rt)
for (caller, caller_pc) in frame.cycle_backedges
# notify backedges of updated type information
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
Expand Down
32 changes: 21 additions & 11 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ const LineNum = Int
mutable struct InferenceState
params::InferenceParams
result::InferenceResult # remember where to put the result
linfo::MethodInstance
linfo::Union{MethodInstance, Nothing}
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
mod::Module
Expand Down Expand Up @@ -37,6 +37,8 @@ mutable struct InferenceState
callers_in_cycle::Vector{InferenceState}
parent::Union{Nothing, InferenceState}

has_yakcs::Bool

# TODO: move these to InferenceResult / Params?
cached::Bool
limited::Bool
Expand All @@ -57,9 +59,23 @@ mutable struct InferenceState
cached::Bool, interp::AbstractInterpreter)
linfo = result.linfo
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)

sp = sptypes_from_meth_instance(linfo::MethodInstance)
if !isa(linfo, Nothing)
toplevel = !isa(linfo.def, Method)
sp = sptypes_from_meth_instance(linfo::MethodInstance)
if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end
else
linfo = nothing
toplevel = true
inmodule = Core
sp = Any[]
end
code = src.code::Array{Any,1}

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
Expand Down Expand Up @@ -93,13 +109,6 @@ mutable struct InferenceState
W = BitSet()
push!(W, 1) #initial pc to visit

if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end

valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
frame = new(
Expand All @@ -112,7 +121,7 @@ mutable struct InferenceState
ssavalue_uses, throw_blocks,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
#=parent=#nothing, #= has_yakcs =# false,
cached, false, false, false,
CachedMethodTable(method_table(interp)),
interp)
Expand Down Expand Up @@ -242,6 +251,7 @@ end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState)
caller.linfo !== nothing || return # don't add backends to yakcs
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
if caller.stmt_edges[caller.currpc] === nothing
caller.stmt_edges[caller.currpc] = []
Expand Down
9 changes: 8 additions & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCache
end

mutable struct OptimizationState
linfo::MethodInstance
linfo::Union{MethodInstance, Nothing}
src::CodeInfo
stmt_info::Vector{Any}
mod::Module
Expand Down Expand Up @@ -311,6 +311,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
ftyp = argextype(farg, src, sptypes, slottypes)
end
end
# Give calls to YAKCs zero cost. The plan is for these to be a single
# indirect call so have very little cost. On the other hand, there
# is enormous benefit to inlining these into a function where we can
# see the definition of the YAKC. Perhaps this should even be negative
if widenconst(ftyp) <: Core.YAKC
return 0
end
f = singleton_type(ftyp)
if isa(f, IntrinsicFunction)
iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
Expand Down
64 changes: 55 additions & 9 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,33 @@ function normalize(@nospecialize(stmt), meta::Vector{Any})
return stmt
end

function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState)
function add_yakc_argtypes!(argtypes, t)
dt = unwrap_unionall(t)
dt1 = unwrap_unionall(dt.parameters[1])
if isa(dt1, TypeVar) || isa(dt1.parameters[1], TypeVar)
push!(argtypes, Any)
else
TT = dt1.parameters[1]
if isa(TT, Union)
TT = tuplemerge(TT.a, TT.b)
end
for p in TT.parameters
push!(argtypes, rewrap_unionall(p, t))
end
end
end


function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState, slottypes=sv.slottypes, stmtinfo=sv.stmt_info)
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels.
idx = 1
oldidx = 1
changemap = fill(0, length(code))
labelmap = coverage ? fill(0, length(code)) : changemap
prevloc = zero(eltype(ci.codelocs))
stmtinfo = sv.stmt_info
stmtinfo = copy(stmtinfo)
yakcs = IRCode[]
while idx <= length(code)
codeloc = ci.codelocs[idx]
if coverage && codeloc != prevloc && codeloc != 0
Expand All @@ -57,7 +75,22 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{}
stmt = code[idx]
if isexpr(stmt, :(=))
stmt = stmt.args[2]
end
ssat = ci.ssavaluetypes[idx]
if isa(ssat, PartialYAKC) && isexpr(stmt, :call)
ft = argextype(stmt.args[1], ci, sv.sptypes)
# Pre-convert any YAKC objects
if isa(ft, Const) && ft.val === Core._yakc && isa(ssat.ci, OptimizationState)
yakc_ir = make_ir(ssat.ci.src, 0, ssat.ci)
push!(yakcs, yakc_ir)
stmt.head = :new_yakc
push!(stmt.args, length(yakcs))
end
end
if stmt isa Expr && ssat === Union{}
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
Expand Down Expand Up @@ -105,7 +138,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
cfg = compute_basic_blocks(code)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, ci.codelocs, flags)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), sv.slottypes, meta, sv.sptypes)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), slottypes, meta, sv.sptypes, yakcs)
return ir
end

Expand All @@ -117,24 +150,37 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState)
return ir
end

function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
preserve_coverage = coverage_enabled(sv.mod)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv)
function compact_all!(ir::IRCode)
length(ir.stmts) == 0 && return ir
for i in 1:length(ir.yakcs)
ir.yakcs[i] = compact_all!(ir.yakcs[i])
end
compact!(ir)
end

function make_ir(ci::CodeInfo, nargs::Int, sv::OptimizationState)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), coverage_enabled(sv.mod), nargs, sv)
ir = slot2reg(ir, ci, nargs, sv)
ir
end

function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
ir = make_ir(ci, nargs, sv)
#@Base.show ("after_construct", ir)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@timeit "compact 1" ir = compact!(ir)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
#@timeit "verify 2" verify_ir(ir)
ir = compact!(ir)
ir = compact_all!(ir)
#@Base.show ("before_sroa", ir)
@timeit "SROA" ir = getfield_elim_pass!(ir)
ir = yakc_optim_pass!(ir)
#@Base.show ir.new_nodes
#@Base.show ("after_sroa", ir)
ir = adce_pass!(ir)
#@Base.show ("after_adce", ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
@timeit "compact 3" ir = compact_all!(ir)
#@Base.show ir
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))
Expand Down
Loading

0 comments on commit 176534b

Please sign in to comment.