From 4ca5f8463e0e1769306d80969915b6b0f0cd78e0 Mon Sep 17 00:00:00 2001 From: Rafael Fourquet Date: Mon, 9 Aug 2021 11:50:06 +0200 Subject: [PATCH] hijack: allow specifying custom include-like functions --- src/ReTest.jl | 52 +++++++++++++++++++--------- src/hijack.jl | 54 +++++++++++++++++++++++------- test/Hijack/test/include_static.jl | 4 ++- test/runtests.jl | 5 +-- 4 files changed, 84 insertions(+), 31 deletions(-) diff --git a/src/ReTest.jl b/src/ReTest.jl index 8027474..386bcb8 100644 --- a/src/ReTest.jl +++ b/src/ReTest.jl @@ -58,6 +58,7 @@ Base.@kwdef mutable struct Options verbose::Bool = false # annotated verbosity transient_verbose::Bool = false # verbosity for next run static_include::Bool = false # whether to execute include at `replace_ts` time + include_functions::Vector{Symbol} = [:include] # functions to treat like include end mutable struct TestsetExpr @@ -121,12 +122,14 @@ function extract_testsets(dest) end # replace unqualified `@testset` by TestsetExpr -function replace_ts(source, mod, x::Expr, parent; static_include::Bool) +function replace_ts(source, mod, x::Expr, parent; static_include::Bool, + include_functions::Vector{Symbol}) if x.head === :macrocall name = x.args[1] if name === Symbol("@testset") @assert x.args[2] isa LineNumberNode ts, hasbroken = parse_ts(x.args[2], mod, Tuple(x.args[3:end]), parent; + include_functions=include_functions, static_include=static_include) ts !== invalid && parent !== nothing && push!(parent.children, ts) ts, false # hasbroken counts only "proper" @test_broken, not recursive ones @@ -136,16 +139,17 @@ function replace_ts(source, mod, x::Expr, parent; static_include::Bool) # `@test` is generally called a lot, so it's probably worth it to skip # the containment test in this case x = macroexpand(mod, x, recursive=false) - replace_ts(source, mod, x, parent; static_include=static_include) + replace_ts(source, mod, x, parent; static_include=static_include, + include_functions=include_functions) else @goto default end - elseif x.head == :call && x.args[1] == :include + elseif x.head == :call && x.args[1] ∈ include_functions path = x.args[end] sourcepath = dirname(string(source.file)) x.args[end] = path isa AbstractString ? - joinpath(sourcepath, path) : - :(joinpath($sourcepath, $path)) + joinpath(sourcepath, path) : + :(joinpath($sourcepath, $path)) if static_include news = InlineTest.get_tests(mod).news newslen = length(news) @@ -166,7 +170,8 @@ function replace_ts(source, mod, x::Expr, parent; static_include::Bool) # below; it's currently not very important tsi.source, tsi.ts...) for tsi in newstmp)...) - replace_ts(source, mod, included_ts, parent; static_include=static_include) + replace_ts(source, mod, included_ts, parent; + static_include=static_include, include_functions=include_functions) else nothing, false end @@ -174,18 +179,21 @@ function replace_ts(source, mod, x::Expr, parent; static_include::Bool) x, false end else @label default - body_br = map(z -> replace_ts(source, mod, z, parent; static_include=static_include), + body_br = map(z -> replace_ts(source, mod, z, parent; static_include=static_include, + include_functions=include_functions), x.args) filter!(x -> first(x) !== invalid, body_br) Expr(x.head, first.(body_br)...), any(last.(body_br)) end end -replace_ts(source, mod, x, _1; static_include::Bool) = x, false +replace_ts(source, mod, x, _1; static_include::Bool, + include_functions) = x, false # create a TestsetExpr from @testset's args function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothing; - static_include::Bool=false) + static_include::Bool=false, include_functions::Vector{Symbol}=[:include]) + function tserror(msg) @error msg _file=String(source.file) _line=source.line _module=mod invalid, false @@ -195,7 +203,7 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi return tserror("expected begin/end block or for loop as argument to @testset") local desc - options = Options() + options = Options(include_functions=include_functions) marks = Marks() if parent !== nothing append!(marks.hard, parent.marks.hard) # copy! not available in Julia 1.0 @@ -203,6 +211,7 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi # if static_include was set in parent, it should have been forwarded also # through the parse_ts/replace_ts call chains: @assert static_include == parent.options.static_include + @assert include_functions === parent.options.include_functions end for arg in args[1:end-1] if arg isa String || Meta.isexpr(arg, :string) @@ -211,10 +220,22 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi # TODO: support non-literal symbols? push!(marks.hard, arg.value) elseif Meta.isexpr(arg, :(=)) - arg.args[1] in fieldnames(Options) || - return tserror("unsupported @testset option") - # TODO: make that work with non-literals: - setfield!(options, arg.args[1], arg.args[2]) + optname = arg.args[1] + optname in fieldnames(Options) || + return tserror("unsupported @testset option: $optname") + if optname == :include_functions + @assert Meta.isexpr(arg.args[2], :vect) + if parent !== nothing + options.include_functions = Symbol[] # make it non-shared + end + for ifn in arg.args[2].args + @assert ifn isa QuoteNode + push!(options.include_functions, ifn.value) + end + else + # TODO: make that work with non-literals: + setfield!(options, optname, arg.args[2]) + end else return tserror("unsupported @testset") end @@ -253,7 +274,8 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi ts = TestsetExpr(source, mod, desc, options, marks, loops, parent) ts.body, ts.hasbroken = replace_ts(source, mod, tsbody, ts; - static_include=options.static_include) + static_include=options.static_include, + include_functions=options.include_functions) ts, false # hasbroken counts only "proper" @test_broken, not recursive ones end diff --git a/src/hijack.jl b/src/hijack.jl index 0fd56c1..d4b9fe1 100644 --- a/src/hijack.jl +++ b/src/hijack.jl @@ -104,8 +104,8 @@ const loaded_testmodules = Dict{Module,Vector{Module}}() """ ReTest.hijack(source, [modname]; - parentmodule::Module=Main, lazy=false, [include::Symbol], - [revise::Bool]) + parentmodule::Module=Main, lazy=false, [revise::Bool], + [include::Symbol], include_functions=[:include]) Given test files defined in `source` using the `Test` package, try to load them by replacing `Test` with `ReTest`, wrapping them in a module `modname` @@ -183,6 +183,15 @@ so the best that can be done here is to "outline" such nested testsets; with `include=:static`, the subdmodules will get defined after `hijack` has returned (on the first call to `retest` thereafter), so won't be "processed". +#### `include_functions` keyword + +When the `include=:static` keyword argument is passed, it's possible to +tell `hijack` to apply the same treatment to other functions than +`include`, by passing a list a symbols to `include_functions`. +For example, if you defined a custom function `custom_include(x)` +which itself calls out to `include`, you can pass +`include_functions=[:custom_include]` to `hijack`. + #### `revise` keyword The `revise` keyword specifies whether `Revise` should be used to track @@ -201,7 +210,8 @@ function hijack end function hijack(path::AbstractString, modname=nothing; parentmodule::Module=Main, lazy=false, revise::Maybe{Bool}=nothing, - include::Maybe{Symbol}=nothing, testset::Bool=false) + include::Maybe{Symbol}=nothing, testset::Bool=false, + include_functions=[:include]) # do first, to error early if necessary Revise = get_revise(revise) @@ -213,6 +223,7 @@ function hijack(path::AbstractString, modname=nothing; parentmodule::Module=Main newmod = @eval parentmodule module $modname end populate_mod!(newmod, path; lazy=lazy, include=setinclude(include, testset), + include_functions=include_functions, Revise=Revise) newmod end @@ -234,12 +245,14 @@ const root_module = Ref{Symbol}() __init__() = root_module[] = gensym("MODULE") -function populate_mod!(mod::Module, path; lazy, Revise, include::Maybe{Symbol}=nothing) +function populate_mod!(mod::Module, path; lazy, Revise, include::Maybe{Symbol}=nothing, + include_functions) lazy ∈ (true, false, :brutal) || throw(ArgumentError("the `lazy` keyword must be `true`, `false` or `:brutal`")) files = Revise === nothing ? nothing : Dict(path => mod) - substitute!(x) = substitute_retest!(x, lazy, include, files) + substitute!(x) = substitute_retest!(x, lazy, include, files; + include_functions=include_functions) @eval mod begin using ReTest # for files which don't have `using Test` @@ -268,7 +281,8 @@ end function hijack(packagemod::Module, modname=nothing; parentmodule::Module=Main, lazy=false, revise::Maybe{Bool}=nothing, - include::Maybe{Symbol}=nothing, testset::Bool=false) + include::Maybe{Symbol}=nothing, testset::Bool=false, + include_functions=[:include]) packagepath = pathof(packagemod) packagepath === nothing && packagemod !== Base && throw(ArgumentError("$packagemod is not a package")) @@ -286,13 +300,15 @@ function hijack(packagemod::Module, modname=nothing; parentmodule::Module=Main, else path = joinpath(dirname(dirname(packagepath)), "test", "runtests.jl") hijack(path, modname, parentmodule=parentmodule, - lazy=lazy, testset=testset, include=include, revise=revise) + lazy=lazy, testset=testset, include=include, revise=revise, + include_functions=include_functions) end end function substitute_retest!(ex, lazy, include_::Maybe{Symbol}, files=nothing; - ishijack::Bool=true) - substitute!(x) = substitute_retest!(x, lazy, include_, files, ishijack=ishijack) + ishijack::Bool=true, include_functions::Vector{Symbol}=[:include]) + substitute!(x) = substitute_retest!(x, lazy, include_, files, ishijack=ishijack, + include_functions=include_functions) if Meta.isexpr(ex, :using) ishijack || return ex @@ -349,16 +365,28 @@ function substitute_retest!(ex, lazy, include_::Maybe{Symbol}, files=nothing; if body.head == :for body = body.args[end] end - includes = splice!(body.args, findall(body.args) do x - Meta.isexpr(x, :call) && x.args[1] == :include - end) + includes = splice!(body.args, + findall(body.args) do x + Meta.isexpr(x, :call) && x.args[1] ∈ include_functions + end) map!(substitute!, includes, includes) ex.head = :block newts = Expr(:macrocall, ex.args...) push!(empty!(ex.args), newts, includes...) - else # :static + elseif include_ == :static pos = ex.args[2] isa LineNumberNode ? 3 : 2 insert!(ex.args, pos, :(static_include=true)) + if !isempty(include_functions) + # enabled only for :static currently, would need to change ex.args + # to newts.args in order to make it work for :outline + ifn_expr = Expr(:vect) + for ifn in include_functions + push!(ifn_expr.args, QuoteNode(ifn)) + end + insert!(ex.args, length(ex.args)-1, :(include_functions=$ifn_expr)) + end + else + @assert false end end elseif ex isa Expr && ex.head ∈ (:block, :let, :for, :while, :if, :try) diff --git a/test/Hijack/test/include_static.jl b/test/Hijack/test/include_static.jl index 96edfb1..cc3a300 100644 --- a/test/Hijack/test/include_static.jl +++ b/test/Hijack/test/include_static.jl @@ -1,8 +1,10 @@ # include = :static using Hijack, Test +custom_include_function(f) = include(f) + @testset "include_static" begin @test true push!(Hijack.RUN, 1) - include("include_static_included1.jl") + custom_include_function("include_static_included1.jl") end diff --git a/test/runtests.jl b/test/runtests.jl index a17ffd7..9ca19ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -961,7 +961,7 @@ end @chapter TestsetErrors begin @test_logs ( :error, "expected begin/end block or for loop as argument to @testset") ( - :error, "unsupported @testset option") ( + :error, "unsupported @testset option: notexistingoption") ( :error, "unsupported @testset" ) ( :error, "expected begin/end block or for loop as argument to @testset") ( :error, "expected begin/end block or for loop as argument to @testset" @@ -1862,7 +1862,8 @@ end empty!(Hijack.RUN) @test_throws ErrorException ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static, testset=true) @test_throws ErrorException ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:notvalid) - ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static) + ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static, + include_functions=[:include, :custom_include_function]) check(HijackInclude, dry=true, verbose=9, [], output=""" 1| include_static 2| include_static_included1 1