diff --git a/src/Traceur.jl b/src/Traceur.jl index 8534388..cc4f224 100644 --- a/src/Traceur.jl +++ b/src/Traceur.jl @@ -6,11 +6,12 @@ using MacroTools import Core.MethodInstance -export @trace, @trace_static +export @trace, @trace_static, @should_not_warn, @check include("util.jl") include("analysis.jl") include("trace.jl") # include("trace_static.jl") +include("check.jl") end # module diff --git a/src/analysis.jl b/src/analysis.jl index 33c77fd..ce4bc81 100644 --- a/src/analysis.jl +++ b/src/analysis.jl @@ -70,9 +70,11 @@ struct Warning call::Call line::Int message::String + stack::Vector{Call} end Warning(call, message) = Warning(call, -1, message) +Warning(call, line, message) = Warning(call, line, message, Call[]) function warning_printer() (w) -> begin diff --git a/src/check.jl b/src/check.jl new file mode 100644 index 0000000..ff0f85f --- /dev/null +++ b/src/check.jl @@ -0,0 +1,40 @@ +const should_not_warn = Set{Function}() + +# this is only a macro to avoid parens around function definitions: +# @should_not_warn +# function foo(x) +# ... +# end +macro should_not_warn(expr) + quote + fun = $(esc(expr)) + push!(should_not_warn, fun) + fun + end +end + +""" + check(f::Function) + +Run Traceur on f, and throw an error if any warnings occur inside functions tagged with @should_not_warn. +""" +function check(f) + failed = false + wp = warning_printer() + result = trace(f) do warning + ix = findfirst((call) -> call.f in should_not_warn, warning.stack) + if ix != nothing + tagged_function = warning.stack[ix].f + message = "$(warning.message) (called from $(tagged_function))" + warning = Warning(warning.call, warning.line, message, warning.stack) + wp(warning) + failed = true + end + end + @assert !failed "One or more warnings occured inside functions tagged with @should_not_warn" + result +end + +macro check(expr) + :(check(() -> $(esc(expr)))) +end diff --git a/src/trace.jl b/src/trace.jl index 823b709..f29b310 100644 --- a/src/trace.jl +++ b/src/trace.jl @@ -4,25 +4,35 @@ Cassette.@context TraceurCtx struct Trace seen::Set + stack::Vector{Call} warn end -Trace(w) = Trace(Set(), w) +Trace(w) = Trace(Set(), Vector{Call}(), w) isprimitive(f) = f isa Core.Builtin || f isa Core.IntrinsicFunction const ignored_methods = Set([@which((1,2)[1])]) const ignored_functions = Set([getproperty, setproperty!]) -function Cassette.posthook(ctx::TraceurCtx, out, f, args...) - C, T = DynamicCall(f, args...), typeof.((f, args)) +function Cassette.prehook(ctx::TraceurCtx, f, args...) tra = ctx.metadata - (f ∈ ignored_functions || T ∈ tra.seen || isprimitive(f) || - method(C) ∈ ignored_methods || method(C).module ∈ (Core, Core.Compiler)) && return nothing + C = DynamicCall(f, args...) + push!(tra.stack, C) + nothing +end - push!(tra.seen, T) - analyse((a...) -> tra.warn(Warning(a...)), C) - return nothing +function Cassette.posthook(ctx::TraceurCtx, out, f, args...) + tra = ctx.metadata + C = tra.stack[end] + T = typeof.((f, args)) + if !(f ∈ ignored_functions || T ∈ tra.seen || isprimitive(f) || + method(C) ∈ ignored_methods || method(C).module ∈ (Core, Core.Compiler)) + push!(tra.seen, T) + analyse((a...) -> tra.warn(Warning(a..., copy(tra.stack))), C) + end + pop!(tra.stack) + nothing end trace(w, f) = Cassette.overdub(TraceurCtx(metadata=Trace(w)), f) diff --git a/test/runtests.jl b/test/runtests.jl index 0157b9e..9b0abf6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,12 @@ function test(warnings) @test isempty(ws) end +x = 1 + +my_add(y) = x + y + +@should_not_warn my_stable_add(y) = my_add(y) + @testset "Traceur" begin @testset "Dynamic" begin test(Traceur.warnings) @@ -61,4 +67,7 @@ end # end # @test_nowarn @trace naive_sum(1.0) # @test_nowarn @trace_static naive_sum(1.0) + + @test_nowarn @check my_add(1) + @test_throws AssertionError @check my_stable_add(1) end