Skip to content

Commit

Permalink
enable toplevel optimization and eliminate some toplevel special casings
Browse files Browse the repository at this point in the history
Better to work with <JuliaLang/julia#42013>, but
I also added an hacky fallback that makes use of the existing method
definition pipeline.
  • Loading branch information
aviatesk committed Aug 26, 2021
1 parent 6591201 commit 7494818
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 49 deletions.
51 changes: 41 additions & 10 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -842,30 +842,32 @@ function report_text(text::AbstractString,
return JETToplevelResult(analyzer′, res, source; analyzer, jetconfigs...)
end

# we have to go on hacks; see `transform_abstract_global_symbols!` and `resolve_toplevel_symbols`
function analyze_toplevel!(analyzer::AbstractAnalyzer, src::CodeInfo)
# construct toplevel `MethodInstance`
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ());
mi.uninferred = src
mi.specTypes = Tuple{}

transform_abstract_global_symbols!(analyzer, src)
mi.def = get_toplevelmod(analyzer)
mi.def = mod = get_toplevelmod(analyzer)
src = transform_abstract_global_symbols!(analyzer, src)
src = resolve_toplevel_symbols(mod, src)
mi.uninferred = src

result = InferenceResult(mi);
# toplevel frame doesn't need to be cached (and so it won't be optimized), nor should
# go through JET's code generation error check
frame = InferenceState(result, src, #=cached=# false, analyzer);
# toplevel frames don't really need to be cached, but still better to be optimized
# in order to get reasonable `LocalUndefVarErrorReport` and `UncaughtExceptionReport`
frame = InferenceState(result, src, #=cached=# true, analyzer);

return analyze_frame!(analyzer, frame)
end

# HACK this is an native hack to re-use `AbstractInterpreter`'s approximated slot types for
# HACK this is very naive hack to re-use `AbstractInterpreter`'s slot type approximation for
# assignments of abstract global variables, which are represented as toplevel symbols at this point;
# the idea is just to transform them into slots from symbols and use their approximated type
# on their assignment.
# the idea is just to transform them into slot from symbol and use their approximated type
# on their assignment (see `finish(::InferenceState, ::AbstractAnalyzer)`).
# NOTE that `transform_abstract_global_symbols!` will produce really invalid code for
# actual interpretation or execution, but all the statements won't be interpreted anymore
# by `ConcreteInterpreter` nor executed anyway since toplevel frames aren't cached
# by `ConcreteInterpreter` nor executed by the native compilation pipeline anyway
function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::CodeInfo)
nslots = length(src.slotnames)
abstrct_global_variables = Dict{Symbol,Int}()
Expand Down Expand Up @@ -901,6 +903,35 @@ function transform_abstract_global_symbols!(analyzer::AbstractAnalyzer, src::Cod
end

set_global_slots!(analyzer, Dict(idx => slotname for (slotname, idx) in abstrct_global_variables))

return src
end

# resolve toplevel symbols (and other expressions like `:foreigncall`)
# so that the returned `CodeInfo` is eligible for abstractintepret and optimization
@static if VERSION v"1.8.0-DEV.420"
function resolve_toplevel_symbols(mod::Module, src::CodeInfo)
newsrc = copy(src)
@ccall jl_resolve_globals_in_ir(newsrc.code::Any, mod::Any, svec()::Any, 1::Any)::Cvoid
return newsrc
end
else
# HACK before https://github.com/JuliaLang/julia/pull/42013, we need to go through
# the method definition pipeline to get the effect of `jl_resolve_globals_in_ir`
function resolve_toplevel_symbols(mod::Module, src::CodeInfo)
sig = Core.svec(
svec(typeof(__toplevelf__)),
svec(),
QuoteNode(LineNumberNode(@__LINE__, @__FILE__)))
# branching on https://github.com/JuliaLang/julia/pull/41137
method = (@static if isdefined(Core.Compiler, :OverlayMethodTable)
ccall(:jl_method_def, Any, (Any, Ptr{Cvoid}, Any, Any), sig, C_NULL, src, mod)
else
ccall(:jl_method_def, Any, (Any, Any, Any), sig, src, mod)
end)::Method
return CC.uncompressed_ir(method)
end
function __toplevelf__ end
end

# TODO `analyze_builtin!` ?
Expand Down
5 changes: 1 addition & 4 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,6 @@ function CC.abstract_eval_special_value(analyzer::AbstractAnalyzer, @nospecializ
# if it's really not defined, the error will be generated later anyway
e = GlobalRef(get_toplevelmod(analyzer), get_slotname(sv, e))
end
elseif isa(e, Symbol)
# (already concretized) toplevel global symbols
e = GlobalRef(get_toplevelmod(analyzer), e)
end
end

Expand Down Expand Up @@ -749,7 +746,7 @@ function is_constant_declared(name::Symbol, sv::InferenceState)
return any(sv.src.code) do @nospecialize(x)
if @isexpr(x, :const)
arg = first(x.args)
# `transform_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s
# `transform_abstract_global_symbols!` replaces all the global symbols in this toplevel frame with `Slot`s
if isa(arg, Slot)
return get_slotname(sv, arg) === name
end
Expand Down
3 changes: 1 addition & 2 deletions src/analyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,8 +482,7 @@ function maybe_initialize_caches!(analyzer::AbstractAnalyzer)
end

# check if we're in a toplevel module
@inline istoplevel(sv::InferenceState) = istoplevel(sv.linfo)
@inline istoplevel(::OptimizationState) = false # optimization never happen for top-level code
@inline istoplevel(sv::State) = istoplevel(sv.linfo)
@inline istoplevel(linfo::MethodInstance) = isa(linfo.def, Module)

is_global_slot(analyzer::AbstractAnalyzer, slot::Int) = slot in keys(get_global_slots(analyzer))
Expand Down
10 changes: 1 addition & 9 deletions src/locinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,7 @@ function _get_sig_type((sv, _)::StateAtPC, arg::Argument)
return Any[sig, typ], typ
end
_get_sig_type(_::StateAtPC, gr::GlobalRef) = Any[string(gr.mod, '.', gr.name)], nothing
function _get_sig_type(s::StateAtPC, name::Symbol)
sv = first(s)
if istoplevel(sv)
# this is concrete global variable, form the global reference
return _get_sig_type(s, GlobalRef(sv.linfo.def, name))
else
return Any[repr(name; context = :compact => true)], nothing
end
end
_get_sig_type(_::StateAtPC, name::Symbol) = Any[repr(name; context = :compact => true)], nothing
function _get_sig_type(s::StateAtPC, gotoifnot::GotoIfNot)
sig = Any[string("goto %", gotoifnot.dest, " if not "), _get_sig(s, gotoifnot.cond)...]
return sig, nothing
Expand Down
22 changes: 2 additions & 20 deletions src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA
throw_locs = get_throw_locs(analyzer)
throw_calls = Tuple{Int,Expr}[]
for (pc, stmt) in enumerate(stmts)
is_throw_call_expr(analyzer, frame, stmt) || continue
isa(stmt, Expr) || continue
is_throw_call(stmt) || continue
# if this `throw` is already reported, don't duplciate
linetable[codelocs[pc]]::LineInfoNode in throw_locs && continue
push!(throw_calls, (pc, stmt))
Expand All @@ -303,22 +304,3 @@ function (::SoundBasicPass)(::Type{UncaughtExceptionReport}, analyzer::AbstractA
empty!(get_uncaught_exceptions(analyzer))
end
end

# basically same as `is_throw_call`, but also toplevel module handling added
function is_throw_call_expr(analyzer::AbstractAnalyzer, frame::InferenceState, @nospecialize(e))
if isa(e, Expr)
if e.head === :call
f = e.args[1]
if istoplevel(frame) && isa(f, Symbol)
f = GlobalRef(get_toplevelmod(analyzer), f)
end
if isa(f, GlobalRef)
ff = CC.abstract_eval_global(f.mod, f.name)
if isa(ff, Const) && ff.val === Core.throw
return true
end
end
end
end
return false
end
8 changes: 4 additions & 4 deletions test/test_abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,7 @@ end
@test isempty(get_reports(analyzer))
end

# with the current approach, local undefined variables in toplevel frame can't be found
# since we don't cache toplevel frame and thus it won't be optimized
let
let # should work for top-level analysis
res = @analyze_toplevel begin
foo = let
if rand(Bool)
Expand All @@ -146,7 +144,9 @@ end
end
end
end
@test_broken !isempty(res.inference_error_reports)
@test length(res.inference_error_reports) === 1 &&
first(res.inference_error_reports) isa LocalUndefVarErrorReport &&
first(res.inference_error_reports).name === :bar
end
end

Expand Down

0 comments on commit 7494818

Please sign in to comment.