Skip to content

Commit

Permalink
wip: fix #203, allow configurations to ignore reports
Browse files Browse the repository at this point in the history
This PR set-ups `ignored_patterns` interface to ignore certain patterns
of reports, which can be configured by users extensively.

`ignored_patterns` is supposed to be an iterator of predicate function
and if any of given `ignored_patterns` matches a report in question,
the report will just be ignored (and nor cached).
The predicate function will be called _before_ a report is actually
constructed for the possible cut-off of the computational cost to
construct the report. Thus its signature is: 
`(T::Type{<:InferenceErrorReport}, interp::JETInterpreter, 
sv::InferenceState, spec_args::Tuple)`, where:
- `T` specifies the kind of report
- `interp` gives the context of whole analysis
- `sv::InferenceState` gives the local context of analysis
- and `spec_args` will be report-specific arguments

To inject predicate checks, now each report is constructed via 
`@report!`
macro, which first checks if the report matches any of 
`ignored_patterns`,
and if not, constructs the report and pushes it to `JETInterpreter` as
the previous `report!` function did.

This PR also defines `DEFAULT_IGNORED_PATTERNS`, namely 
`ignored_patterns`
applied by default. Currently it consists of 
`ignore_corecompiler_undefglobal`,
which ignores undefined global names in `Core.Compiler`, and it should
have some positive effects, like reduce false positive error reports
involved with `Base.return_types` and its family in general, improve
analysis performance for JET's self-profiling, etc.

The support of user-predicate functions specified via .JET configuration
file is somewhat tricky, but currently this PR uses 
RuntimeGeneratedFunctions.jl
and it seems to work. I'm not sure if `@nospecialize` notations works
correctly in `RuntimeGeneratedFunction`s, so some insights or benchmarks
on it will be very welcomed.
  • Loading branch information
aviatesk committed May 27, 2021
1 parent 8aeff9a commit 16a6ef4
Show file tree
Hide file tree
Showing 8 changed files with 124 additions and 48 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ LoweredCodeUtils = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"

[compat]
JuliaInterpreter = "0.8.16"
Expand Down
25 changes: 23 additions & 2 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ import JuliaInterpreter:
moduleof,
@lookup

import MacroTools: MacroTools, @capture, normalise, striplines
import MacroTools: MacroTools, @capture, normalise, striplines, shortdef

using RuntimeGeneratedFunctions

using InteractiveUtils

Expand All @@ -185,6 +187,8 @@ const INIT_HOOKS = Function[]
push_inithook!(f) = push!(INIT_HOOKS, f)
__init__() = foreach(@nospecialize(f)->f(), INIT_HOOKS)

RuntimeGeneratedFunctions.init(@__MODULE__)

# compat
# ------

Expand All @@ -202,6 +206,12 @@ else
ignorelimited(@nospecialize(x)) = x
end

@static if isdefined(Base, Symbol("@aggressive_constprop"))
import Base: @aggressive_constprop
else
macro aggressive_constprop(x) esc(x) end # not available
end

# macros
# ------

Expand Down Expand Up @@ -615,6 +625,11 @@ function process_config_dict!(config_dict)
@assert isa(concretization_patterns, Vector{String}) "`concretization_patterns` should be array of string of Julia expression"
config_dict["concretization_patterns"] = trymetaparse.(concretization_patterns)
end
ignored_patterns = get(config_dict, "ignored_patterns", nothing)
if !isnothing(ignored_patterns)
@assert isa(ignored_patterns, Vector{String}) "`ignored_patterns` should be array of string of Julia expression"
config_dict["ignored_patterns"] = rgf.(trymetaparse.(ignored_patterns))
end
toplevel_logger = get(config_dict, "toplevel_logger", nothing)
if !isnothing(toplevel_logger)
@assert isa(toplevel_logger, String) "`toplevel_logger` should be string of Julia code"
Expand All @@ -634,6 +649,12 @@ function trymetaparse(s)
return ret
end

function rgf(x)
# normalize so that `cache_key` doesn't have the whitespace/newline sensitivity of this definition
ex = striplines(shortdef(x))
@RuntimeGeneratedFunction(@__MODULE__, ex)
end

function kwargs(dict)
ns = (Symbol.(keys(dict))...,)
vs = (collect(values(dict))...,)
Expand Down Expand Up @@ -848,7 +869,7 @@ function may_report_get_staged!(interp::JETInterpreter, mi::MethodInstance)
ccall(:jl_code_for_staged, Any, (Any,), mi)
catch err
# if user code throws error, wrap and report it
report!(interp, GeneratorErrorReport(interp, mi, err))
@report!(GeneratorErrorReport(interp, mi, err))
end
end

Expand Down
18 changes: 9 additions & 9 deletions src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function CC.abstract_call_gf_by_type(interp::JETInterpreter, @nospecialize(f),
end
if isa(info, MethodMatchInfo)
if is_empty_match(info)
report!(interp, NoMethodErrorReport(interp, sv, atype))
@report!(NoMethodErrorReport(interp, sv, atype))
end
elseif isa(info, UnionSplitInfo)
# check each match for union-split signature
Expand All @@ -73,7 +73,7 @@ function CC.abstract_call_gf_by_type(interp::JETInterpreter, @nospecialize(f),
end

if !isnothing(ts)
report!(interp, NoMethodErrorReport(interp, sv, ts))
@report!(NoMethodErrorReport(interp, sv, ts))
end
end

Expand Down Expand Up @@ -306,7 +306,7 @@ function CC.abstract_invoke(interp::JETInterpreter, argtypes::Vector{Any}, sv::I
# if the error type (`Bottom`) is propagated from the `invoke`d call, the error has
# already been reported within `typeinf_edge`, so ignore that case
if !isa(ret.info, InvokeCallInfo)
report!(interp, InvalidInvokeErrorReport(interp, sv, argtypes))
@report!(InvalidInvokeErrorReport(interp, sv, argtypes))
end
end

Expand Down Expand Up @@ -386,7 +386,7 @@ function CC.abstract_eval_special_value(interp::JETInterpreter, @nospecialize(e)
end
else
# report access to undefined global variable
report!(interp, GlobalUndefVarErrorReport(interp, sv, mod, name))
@report!(GlobalUndefVarErrorReport(interp, sv, mod, name))

# `ret` at this point should be annotated as `Any` by `NativeInterpreter`, and
# we just pass it as is to collect as much error points as possible within this
Expand Down Expand Up @@ -433,11 +433,11 @@ function CC.abstract_eval_value(interp::JETInterpreter, @nospecialize(e), vtypes
end
end
if !isempty(ts)
report!(interp, NonBooleanCondErrorReport(interp, sv, ts))
@report!(NonBooleanCondErrorReport(interp, sv, ts))
end
else
if typeintersect(Bool, t) !== Bool
report!(interp, NonBooleanCondErrorReport(interp, sv, t))
@report!(NonBooleanCondErrorReport(interp, sv, t))
ret = Bottom
end
end
Expand Down Expand Up @@ -555,7 +555,7 @@ function set_abstract_global!(interp, mod, name, @nospecialize(t), isnd, sv)
prev_t = val.t
if val.iscd && widenconst(prev_t) !== widenconst(t)
warn_invalid_const_global!(name)
report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, widenconst(prev_t), widenconst(t)))
@report!(InvalidConstantRedefinition(interp, sv, mod, name, widenconst(prev_t), widenconst(t)))
return
end
prev_agv = val
Expand All @@ -565,7 +565,7 @@ function set_abstract_global!(interp, mod, name, @nospecialize(t), isnd, sv)
invalid = prev_t !== widenconst(t)
if invalid || !isa(t, Const)
warn_invalid_const_global!(name)
invalid && report!(interp, InvalidConstantRedefinition(interp, sv, mod, name, prev_t, widenconst(t)))
invalid && @report!(InvalidConstantRedefinition(interp, sv, mod, name, prev_t, widenconst(t)))
return
end
# otherwise, we can just redefine this constant, and Julia will warn it
Expand All @@ -580,7 +580,7 @@ function set_abstract_global!(interp, mod, name, @nospecialize(t), isnd, sv)
# if this constant declaration is invalid, just report it and bail out
if iscd && !isnew
warn_invalid_const_global!(name)
report!(interp, InvalidConstantDeclaration(interp, sv, mod, name))
@report!(InvalidConstantDeclaration(interp, sv, mod, name))
return
end

Expand Down
94 changes: 74 additions & 20 deletions src/abstractinterpreterinterface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# configurations
# --------------

# TODO more configurations, e.g. ignore user-specified modules and such
"""
Configurations for JET analysis.
These configurations will be active for all the entries.
Expand Down Expand Up @@ -56,23 +55,87 @@ These configurations will be active for all the entries.
│││││││└───────────────
Dict{Any, Int64}
```
---
$(""#=---
- `ignore_native_remarks::Bool = true` \\
If `true`, JET won't construct nor cache reports of "native remarks", which may speed up analysis time.
"Native remarks" are information that Julia's native compiler emits about how type inference routine goes,
and those remarks are less interesting in term of "error checking", so JET ignores them by default.
and those remarks are less interesting in term of "error checking", so JET ignores them by default.=#)
"""
struct JETAnalysisParams
struct JETAnalysisParams{X}
ignored_patterns::X
strict_condition_check::Bool
ignore_native_remarks::Bool
@jetconfigurable JETAnalysisParams(; strict_condition_check::Bool = false,
ignore_native_remarks::Bool = true,
) =
return new(strict_condition_check,
ignore_native_remarks,
)
# ignore_native_remarks::Bool
# `@aggressive_constprop` here makes sure `enable_default_ignored_patterns` to be propagated as constant
@aggressive_constprop @jetconfigurable function JETAnalysisParams(;
ignored_patterns::T = (),
enable_default_ignored_patterns::Bool = true,
strict_condition_check::Bool = false,
# ignore_native_remarks::Bool = true,
) where T
if enable_default_ignored_patterns
ignored_patterns = tuple(ignored_patterns..., DEFAULT_IGNORED_PATTERNS...)
X = Core.Typeof(ignored_patterns)
else
X = T
end
return new{X}(ignored_patterns,
strict_condition_check,
# ignore_native_remarks,
)
end
end

macro report!(reportcall, target = QuoteNode(:reports))
@assert @isexpr(reportcall, :call)
@assert length(reportcall.args) 3 # (T<:InferenceErrorReport)(interp, sv, specs...)
T, interp, sv, specs... = reportcall.args
spec_names = [gensym() for _ in 1:length(specs)]
destruct_lhs = Expr(:tuple, :T, :interp, :sv)
append!(destruct_lhs.args, spec_names)
destruct_rhs = Expr(:tuple, esc(T), esc(interp), esc(sv), esc.(specs)...)
let_destruct = Expr(:(=), destruct_lhs, destruct_rhs)
constructor = Expr(:call, :T, :interp, :sv)
append!(constructor.args, spec_names)
return :(let $let_destruct
local skip = false
for predicate in $JETAnalysisParams(interp).ignored_patterns
if predicate(T, interp, sv, ($(spec_names...),))
skip = true
break
end
end
if !skip
target = getproperty(interp, $target)
push!(target, $constructor)
end
end)
end

"""
ignore_corecompiler_undefglobal
Ignores error points reported at an undefined global binding in `Core.Compiler`, as far as
the undefined name is defined in the corresponding `Base` module.
`Core.Compiler` reuses minimum amount of `Base` definitions and there're some of missing
definitions, but they usually don't matter and `Core.Compiler`'s basic functionality is
battle-tested and validated exhausively by its test suite and real-world usages !
"""
function ignore_corecompiler_undefglobal(@nospecialize(T), interp, sv, @nospecialize(spec_args))
T === GlobalUndefVarErrorReport || return false
mod, name = spec_args::Tuple{Module,Symbol}
if mod === Core.Compiler
return isdefined(Base, name)
elseif mod === Core.Compiler.Sort
return isdefined(Base.Sort, name)
else
return false
end
end

const DEFAULT_IGNORED_PATTERNS = tuple(
ignore_corecompiler_undefglobal,
)

"""
Configurations for Julia's native type inference routine.
These configurations will be active for all the entries.
Expand Down Expand Up @@ -431,15 +494,6 @@ JETAnalysisParams(interp::JETInterpreter) = interp.analysis_params

JETLogger(interp::JETInterpreter) = interp.logger

# TODO do report filtering or something configured by `JETAnalysisParams(interp)`
function report!(interp::JETInterpreter, report::InferenceErrorReport)
push!(interp.reports, report)
end

function stash_uncaught_exception!(interp::JETInterpreter, report::UncaughtExceptionReport)
push!(interp.uncaught_exceptions, report)
end

# check if we're in a toplevel module
@inline istoplevel(interp::JETInterpreter, sv::InferenceState) = istoplevel(interp, sv.linfo)
@inline istoplevel(interp::JETInterpreter, linfo::MethodInstance) = interp.toplevelmod === linfo.def
Expand Down
4 changes: 2 additions & 2 deletions src/legacy/abstractinterpretation
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg
#=== abstract_call_gf_by_type patch point 1-1 start ===#
if ts !== nothing
# report `NoMethodErrorReport` for each union-split signature
$report!(interp, $NoMethodErrorReport(interp, sv, ts))
$(macroexpand(JET, :(@report!($NoMethodErrorReport(interp, sv, ts)))))
end
#=== abstract_call_gf_by_type patch point 1-1 end ===#
info = UnionSplitInfo(infos)
Expand All @@ -104,7 +104,7 @@ function abstract_call_gf_by_type(interp::$JETInterpreter, @nospecialize(f), arg
#=== abstract_call_gf_by_type patch point 1-2 start ===#
if $is_empty_match(info)
# report `NoMethodErrorReport` for this call signature
$report!(interp, $NoMethodErrorReport(interp, sv, atype))
$(macroexpand(JET, :(@report!($NoMethodErrorReport(interp, sv, atype)))))
end
#=== abstract_call_gf_by_type patch point 1-2 end ===#
applicable = matches.matches
Expand Down
4 changes: 2 additions & 2 deletions src/reports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ end
function restore_cached_report!(cache::InferenceErrorReportCache, interp)
report = restore_cached_report(cache)
if isa(report, UncaughtExceptionReport)
stash_uncaught_exception!(interp, report)
push!(interp.uncaught_exceptions, report)
else
report!(interp, report)
push!(interp.reports, report)
end
return report
end
Expand Down
20 changes: 10 additions & 10 deletions src/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function CC.builtin_tfunction(interp::JETInterpreter, @nospecialize(f), argtypes
if isa(a, Const)
v = a.val
if isa(v, UndefKeywordError)
report!(interp, UndefKeywordErrorReport(interp, sv, v, get_lin(sv)))
@report!(UndefKeywordErrorReport(interp, sv, v, get_lin(sv)))
end
end
end
Expand All @@ -40,13 +40,13 @@ function CC.builtin_tfunction(interp::JETInterpreter, @nospecialize(f), argtypes
# TODO; `ret` should be `Any` here, add report pass here (for performance linting)
else
# report access to undefined global variable
report!(interp, GlobalUndefVarErrorReport(interp, sv, mod, name))
@report!(GlobalUndefVarErrorReport(interp, sv, mod, name))
# return Bottom
end
elseif ret === Bottom
# general case when an error is detected by the native `getfield_tfunc`
typ = widenconst(obj)
report!(interp, NoFieldErrorReport(interp, sv, typ, name))
@report!(NoFieldErrorReport(interp, sv, typ, name))
return ret
end
end
Expand Down Expand Up @@ -75,7 +75,7 @@ function CC.builtin_tfunction(interp::JETInterpreter, @nospecialize(f), argtypes
t = widenconst(a)
if t <: Base.BitSigned64 || t <: Base.BitUnsigned64
if isa(a, Const) && a.val === zero(t)
report!(interp, DivideErrorReport(interp, sv))
@report!(DivideErrorReport(interp, sv))
return Bottom
end
end
Expand All @@ -85,7 +85,7 @@ function CC.builtin_tfunction(interp::JETInterpreter, @nospecialize(f), argtypes
# XXX: for general case, JET just relies on the (maybe too persmissive) return type
# from native tfuncs to report invalid builtin calls and probably there're lots of
# false negatives
report!(interp, InvalidBuiltinCallErrorReport(interp, sv, argtypes))
@report!(InvalidBuiltinCallErrorReport(interp, sv, argtypes))
end

return ret
Expand All @@ -108,11 +108,11 @@ end
function CC.return_type_tfunc(interp::JETInterpreter, argtypes::Vector{Any}, sv::InferenceState)
if length(argtypes) 3
# invalid argument number, let's report and return error result (i.e. `Bottom`)
report!(interp, NoMethodErrorReport(interp,
sv,
# this is not necessary to be computed correctly, though
argtypes_to_type(argtypes),
))
@report!(NoMethodErrorReport(interp,
sv,
# this is not necessary to be computed correctly, though
argtypes_to_type(argtypes),
))
@static if isdefined(CC, :ReturnTypeCallInfo)
return CallMeta(Bottom, nothing)
else
Expand Down
6 changes: 3 additions & 3 deletions src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function CC._typeinf(interp::JETInterpreter, frame::InferenceState)
# `:(unreachable)` are introduced by `optimize`
for (idx, stmt) in enumerate(stmts)
if isa(stmt, Expr) && stmt.head === :throw_undef_if_not
sym, _ = stmt.args
sym::Symbol, _ = stmt.args

# slots in toplevel frame may be a abstract global slot
istoplevel(interp, frame) && is_global_slot(interp, sym) && continue
Expand All @@ -103,7 +103,7 @@ function CC._typeinf(interp::JETInterpreter, frame::InferenceState)
# the optimization so far has found this statement is never "reachable";
# JET reports it since it will invoke undef var error at runtime, or will just
# be dead code otherwise
report!(interp, LocalUndefVarErrorReport(interp, frame, sym, idx))
@report!(LocalUndefVarErrorReport(interp, frame, sym, idx))
# else
# by excluding this pass, JET accepts some false negatives (i.e. don't report
# those that may actually happen on actual execution)
Expand Down Expand Up @@ -142,7 +142,7 @@ function CC._typeinf(interp::JETInterpreter, frame::InferenceState)
push!(throw_calls, stmt)
end
if !isempty(throw_calls)
stash_uncaught_exception!(interp, UncaughtExceptionReport(interp, frame, throw_calls))
@report!(UncaughtExceptionReport(interp, frame, throw_calls), :uncaught_exceptions)
end
else
# the non-`Bottom` result here may mean `throw` calls from the children frames
Expand Down

0 comments on commit 16a6ef4

Please sign in to comment.