Implement opaque closures
This is the end result of the design process in #31253.

# Overview

This PR implements a new kind of closure, called an `opaque closure`.
It is designed to be complementary to the existing closure mechanism
and makes some different trade-offs. The motivation for this mechanism
comes primarily from closure-based AD tools, but I'm expecting it will
find other use cases as well. From the end user perspective, opaque
closures basically behave like regular closures, except that they
are introduced by adding the `@opaque` macro (not part of this PR,
but will be added after) in front of the existing closure. In
particular, all scoping, capture, etc. rules are identical. For
such user-written closures, the primary difference is in the
performance characteristics. In particular:

1) Passing an opaque closure to a higher-order function will specialize
   on the argument and return types of the closure, but not on the
   closure identity. (This also means that the opaque closure will not
   be eligible for inlining into the higher-order function, unless the
   inliner can see both the definition and the call site).
2) The optimizer is allowed to modify the capture environment of the
   opaque closure (e.g. dropping unused captures, or reducing `Box`ed
   values back to value captures).

The `opaque` part of the naming comes from the notion that semantically,
nothing is supposed to inspect either the code or the capture environment
of the opaque closure, since the optimizer is allowed to choose any value
for these that preserves the behavior of calling the opaque closure itself.
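
To make point 1 concrete, the following sketch shows the behavior of ordinary closures that opaque closures are meant to change: every regular closure has its own type, so a higher-order function is specialized once per closure. The names `apply_twice`, `add_one`, and `double` are made up for illustration, and nothing here uses the new builtin.

```julia
# Illustrative only: ordinary closures, no opaque-closure machinery involved.
apply_twice(f, x) = f(f(x))

add_one = x -> x + 1
double = x -> 2x

# Each anonymous function gets its own concrete type, so `apply_twice`
# is compiled separately for `add_one` and for `double`.
@show typeof(add_one) == typeof(double) # false
@show apply_twice(add_one, 1) # 3
@show apply_twice(double, 1) # 4
```

Under the design described here, two opaque closures with the same argument and return types would instead share a single type, so the higher-order function would only need to be specialized on that signature.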

# Motivation
## Optimization across closure boundaries

Consider the following situation (type annotations are inference
results, not type asserts):
```
function foo()
    a = expensive_but_effect_free()::Any
    b = something()::Float64
    ()->isa(b, Float64) ? return nothing : return a
end
```

Now, the traditional closure mechanism will lower this to:

```
struct ###{T, S}
   a::T
   b::S
end
(x::###{T,S})() where {T,S} = isa(x.b, Float64) ? return nothing : return x.a
function foo()
    a = expensive_but_effect_free()::Any
    b = something()::Float64
    new(a, b)
end
```

The problem with this is apparent: even though (after inference),
we know that `a` is unused in the closure (and thus would be
able to delete the expensive call were it not for the capture),
we may not delete it, simply because we need to satisfy the full
capture list of the closure. Ideally, we would like to have a mechanism
where the optimizer may modify the capture list of a closure in
response to information it discovers.
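
The effect can be reproduced by hand with a callable struct. This is a sketch under assumptions: `ManualClosure`, `make_closure`, and `expensive` are hypothetical stand-ins for the lowered closure type, `foo`, and `expensive_but_effect_free`, and the struct is written out manually rather than generated by the frontend.

```julia
# Hypothetical stand-ins; not the code the frontend actually generates.
expensive() = collect(1:1000)

struct ManualClosure{T, S}
    a::T
    b::S
end

# Calling the struct plays the role of calling the closure.
(x::ManualClosure)() = isa(x.b, Float64) ? nothing : x.a

function make_closure()
    a = expensive()             # never read when `b` is a Float64 ...
    b = 1.0
    return ManualClosure(a, b)  # ... but the struct still has to store it
end

c = make_closure()
@show c()          # nothing
@show length(c.a)  # 1000: the unused capture is still stored in the closure
```

Because `a` is a field of the returned object, the call that produces it cannot be removed without changing the object's layout, which is exactly the freedom an opaque capture environment is meant to give the optimizer.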

## Closures from Cassette transforms

Compiler passes like Zygote would like to generate new closures
from untyped IR (i.e. after the frontend runs), and in the future
potentially from typed IR as well. We currently do not have a great
mechanism to support this. Opaque closures provide a very straightforward
implementation of this feature, as they may be inserted at any point during
the compilation process (unlike types, which may only be inserted
by the frontend).

# Mechanism

The primary concept introduced by this PR is the `OpaqueClosure{A<:Tuple, R}`
type, constructed by the new `Core._opaque_closure` builtin, with
the following signature:

```
    _opaque_closure(argt::Type{<:Tuple}, lb::Type, ub::Type, source::CodeInfo, captures...)

Create a new OpaqueClosure taking arguments specified by the types `argt`. When called,
this opaque closure will execute the source specified in `source`. The `lb` and `ub`
arguments constrain the return type of the opaque closure. In particular, any return
value of type `Core.OpaqueClosure{argt, R} where lb<:R<:ub` is semantically valid. If
the optimizer runs, it may replace `R` by the narrowest type that inference
was able to determine. To guarantee a particular value of `R`, set `lb === ub`.
```

Captures are available to the CodeInfo as `getfield` from Slot 1
(referenced by position).
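
As a rough illustration of the `OpaqueClosure{A<:Tuple, R}` type, the sketch below assumes a later Julia version in which the `@opaque` macro is available as `Base.Experimental.@opaque`. That macro is not part of this PR, and the builtin's signature in such a version may differ from the one documented above. The helper `make_adder` is a made-up name. The example only checks properties that follow from the description here: the argument types appear as a tuple type in the first parameter, and calling the object runs the closure body with its captures.

```julia
# Assumes Base.Experimental.@opaque exists (later Julia releases); not part of this PR.
using Base.Experimental: @opaque

# `b` is captured by the opaque closure; `a` is its only argument.
make_adder(b::Int) = @opaque (a::Int) -> a + b

oc = make_adder(40)
@show oc isa Core.OpaqueClosure{Tuple{Int}} # true
@show oc(2) # 42
```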

# Examples

I think the easiest way to understand opaque closures is to look through
a few examples. These make use of the `@opaque` macro, which isn't
implemented yet, but makes understanding the semantics easier.
Some of these examples, in currently available syntax, can be seen
in test/opaque_closure.jl.

```
oc_trivial() = @opaque ()::Any->1
@show oc_trivial() # ()::Any->◌
@show oc_trivial()() # 1

oc_inf() = @opaque ()->1
# Int return type is inferred
@show oc_inf() # ()::Int->◌
@show oc_inf()() # 1

function local_call(b::Int)
    f = @opaque (a::Int)->a + b
    f(2)
end

oc_capture_opt(A) = @opaque (B::typeof(A))->ndims(A)*B
@show oc_capture_opt([1; 2]) # (::Vector{Int},)::Vector{Int}->◌
@show sizeof(oc_capture_opt([1; 2]).env) # 0
@show oc_capture_opt([1 2])([3 4]) # [6 8]
```
Keno committed Oct 13, 2020
1 parent c91a750 commit 9cc49d3
Showing 34 changed files with 748 additions and 67 deletions.
3 changes: 3 additions & 0 deletions base/Base.jl
@@ -232,6 +232,9 @@ using .Libc: getpid, gethostname, time

include("env.jl")

# OpaqueClosure
include("opaque_closure.jl")

# Concurrency
include("linked_list.jl")
include("condition.jl")
36 changes: 32 additions & 4 deletions base/compiler/abstractinterpretation.jl
@@ -89,7 +89,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
edges = Any[]
nonbot = 0 # the index of the only non-Bottom inference result if > 0
seen = 0 # number of signatures actually inferred
istoplevel = sv.linfo.def isa Module
istoplevel = sv.linfo !== nothing && sv.linfo.def isa Module
multiple_matches = napplicable > 1

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
@@ -337,7 +337,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing}
while !(infstate === nothing)
infstate = infstate::InferenceState
if method === infstate.linfo.def
if infstate.linfo !== nothing && method === infstate.linfo.def
if infstate.linfo.specTypes == sig
# avoid widening when detecting self-recursion
# TODO: merge call cycle and return right away
@@ -790,6 +790,10 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
ty = typeintersect(ty, cnd.elsetype)
end
return tmerge(tx, ty)
elseif f === Core._opaque_closure
la >= 5 || return Union{}
return _opaque_closure_tfunc(argtypes[2], argtypes[3], argtypes[4],
argtypes[5], argtypes[6:end], sv.linfo)
end
rt = builtin_tfunction(interp, f, argtypes[2:end], sv)
if f === getfield && isa(fargs, Vector{Any}) && la == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple
@@ -1003,6 +1007,28 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
return abstract_call_gf_by_type(interp, f, argtypes, atype, sv, max_methods)
end

function abstract_call_opaque_closure(interp::AbstractInterpreter, clos::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
if isa(clos.ci, CodeInfo)
nargtypes = argtypes[2:end]
pushfirst!(nargtypes, clos.env)
result = InferenceResult(Core.OpaqueClosure, nargtypes)
state = InferenceState(result, copy(clos.ci), false, interp)
typeinf_local(interp, state)
finish(state, interp)
result.src.src.inferred = true
clos.ci = result.src
return CallMeta(result.result, false)
elseif isa(clos.ci, OptimizationState)
return CallMeta(clos.ci.src.rettype, nothing)
else
nargtypes = argtypes[2:end]
pushfirst!(nargtypes, Core.OpaqueClosure)
sig = argtypes_to_type(nargtypes)
rt, edge = abstract_call_method(interp, clos.ci::Method, sig, Core.svec(), false, sv)
return CallMeta(rt, edge)
end
end

# call where the function is any lattice element
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
@@ -1014,6 +1040,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
f = ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes, sv)
else
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
@@ -1326,14 +1354,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
# only propagate information we know we can store
# and is valid inter-procedurally
rt = widenconst(rt)
end
if tchanged(rt, frame.bestguess)
# new (wider) return type for frame
frame.bestguess = tmerge(frame.bestguess, rt)
frame.bestguess = frame.bestguess === NOT_FOUND ? rt : tmerge(frame.bestguess, rt)
for (caller, caller_pc) in frame.cycle_backedges
# notify backedges of updated type information
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
32 changes: 21 additions & 11 deletions base/compiler/inferencestate.jl
@@ -5,7 +5,7 @@ const LineNum = Int
mutable struct InferenceState
params::InferenceParams
result::InferenceResult # remember where to put the result
linfo::MethodInstance
linfo::Union{MethodInstance, Nothing}
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
mod::Module
@@ -37,6 +37,8 @@ mutable struct InferenceState
callers_in_cycle::Vector{InferenceState}
parent::Union{Nothing, InferenceState}

has_opaque_closures::Bool

# TODO: move these to InferenceResult / Params?
cached::Bool
limited::Bool
@@ -57,9 +59,23 @@ mutable struct InferenceState
cached::Bool, interp::AbstractInterpreter)
linfo = result.linfo
code = src.code::Array{Any,1}
toplevel = !isa(linfo.def, Method)

sp = sptypes_from_meth_instance(linfo::MethodInstance)
if !isa(linfo, Nothing)
toplevel = !isa(linfo.def, Method)
sp = sptypes_from_meth_instance(linfo::MethodInstance)
if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end
else
linfo = nothing
toplevel = true
inmodule = Core
sp = Any[]
end
code = src.code::Array{Any,1}

nssavalues = src.ssavaluetypes::Int
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
@@ -93,13 +109,6 @@ mutable struct InferenceState
W = BitSet()
push!(W, 1) #initial pc to visit

if !toplevel
meth = linfo.def
inmodule = meth.module
else
inmodule = linfo.def::Module
end

valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
frame = new(
@@ -112,7 +121,7 @@ mutable struct InferenceState
ssavalue_uses, throw_blocks,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
#=parent=#nothing, #= has_opaque_closures =# false,
cached, false, false, false,
CachedMethodTable(method_table(interp)),
interp)
@@ -242,6 +251,7 @@ end

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(li::MethodInstance, caller::InferenceState)
caller.linfo !== nothing || return # don't add backedges to opaque closures
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
if caller.stmt_edges[caller.currpc] === nothing
caller.stmt_edges[caller.currpc] = []
9 changes: 8 additions & 1 deletion base/compiler/optimize.jl
@@ -34,7 +34,7 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCache
end

mutable struct OptimizationState
linfo::MethodInstance
linfo::Union{MethodInstance, Nothing}
src::CodeInfo
stmt_info::Vector{Any}
mod::Module
@@ -311,6 +311,13 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
ftyp = argextype(farg, src, sptypes, slottypes)
end
end
# Give calls to OpaqueClosures zero cost. The plan is for these to be a single
# indirect call so have very little cost. On the other hand, there
# is enormous benefit to inlining these into a function where we can
# see the definition of the OpaqueClosure. Perhaps this should even be negative
if widenconst(ftyp) <: Core.OpaqueClosure
return 0
end
f = singleton_type(ftyp)
if isa(f, IntrinsicFunction)
iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
64 changes: 55 additions & 9 deletions base/compiler/ssair/driver.jl
@@ -33,15 +33,33 @@ function normalize(@nospecialize(stmt), meta::Vector{Any})
return stmt
end

function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState)
function add_opaque_closure_argtypes!(argtypes, t)
dt = unwrap_unionall(t)
dt1 = unwrap_unionall(dt.parameters[1])
if isa(dt1, TypeVar) || isa(dt1.parameters[1], TypeVar)
push!(argtypes, Any)
else
TT = dt1.parameters[1]
if isa(TT, Union)
TT = tuplemerge(TT.a, TT.b)
end
for p in TT.parameters
push!(argtypes, rewrap_unionall(p, t))
end
end
end


function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, nargs::Int, sv::OptimizationState, slottypes=sv.slottypes, stmtinfo=sv.stmt_info)
# Go through and add an unreachable node after every
# Union{} call. Then reindex labels.
idx = 1
oldidx = 1
changemap = fill(0, length(code))
labelmap = coverage ? fill(0, length(code)) : changemap
prevloc = zero(eltype(ci.codelocs))
stmtinfo = sv.stmt_info
stmtinfo = copy(stmtinfo)
opaques = IRCode[]
while idx <= length(code)
codeloc = ci.codelocs[idx]
if coverage && codeloc != prevloc && codeloc != 0
@@ -57,7 +75,22 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{}
stmt = code[idx]
if isexpr(stmt, :(=))
stmt = stmt.args[2]
end
ssat = ci.ssavaluetypes[idx]
if isa(ssat, PartialOpaque) && isexpr(stmt, :call)
ft = argextype(stmt.args[1], ci, sv.sptypes)
# Pre-convert any OpaqueClosure objects
if isa(ft, Const) && ft.val === Core._opaque_closure && isa(ssat.ci, OptimizationState)
opaque_ir = make_ir(ssat.ci.src, 0, ssat.ci)
push!(opaques, opaque_ir)
stmt.head = :new_opaque_closure
stmt.args[5] = OpaqueClosureIdx(length(opaques))
end
end
if stmt isa Expr && ssat === Union{}
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
@@ -105,7 +138,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
cfg = compute_basic_blocks(code)
types = Any[]
stmts = InstructionStream(code, types, stmtinfo, ci.codelocs, flags)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), sv.slottypes, meta, sv.sptypes)
ir = IRCode(stmts, cfg, collect(LineInfoNode, ci.linetable), slottypes, meta, sv.sptypes, opaques)
return ir
end

@@ -117,24 +150,37 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState)
return ir
end

function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
preserve_coverage = coverage_enabled(sv.mod)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv)
function compact_all!(ir::IRCode)
length(ir.stmts) == 0 && return ir
for i in 1:length(ir.opaques)
ir.opaques[i] = compact_all!(ir.opaques[i])
end
compact!(ir)
end

function make_ir(ci::CodeInfo, nargs::Int, sv::OptimizationState)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), coverage_enabled(sv.mod), nargs, sv)
ir = slot2reg(ir, ci, nargs, sv)
ir
end

function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
ir = make_ir(ci, nargs, sv)
#@Base.show ("after_construct", ir)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@timeit "compact 1" ir = compact!(ir)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
#@timeit "verify 2" verify_ir(ir)
ir = compact!(ir)
ir = compact_all!(ir)
#@Base.show ("before_sroa", ir)
@timeit "SROA" ir = getfield_elim_pass!(ir)
ir = opaque_closure_optim_pass!(ir)
#@Base.show ir.new_nodes
#@Base.show ("after_sroa", ir)
ir = adce_pass!(ir)
#@Base.show ("after_adce", ir)
@timeit "type lift" ir = type_lift_pass!(ir)
@timeit "compact 3" ir = compact!(ir)
@timeit "compact 3" ir = compact_all!(ir)
#@Base.show ir
if JLOptions().debug_level == 2
@timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable))