Skip to content

Commit

Permalink
hijack: allow specifying custom include-like functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rfourquet committed Aug 9, 2021
1 parent 697e528 commit 4ca5f84
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 31 deletions.
52 changes: 37 additions & 15 deletions src/ReTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Base.@kwdef mutable struct Options
verbose::Bool = false # annotated verbosity
transient_verbose::Bool = false # verbosity for next run
static_include::Bool = false # whether to execute include at `replace_ts` time
include_functions::Vector{Symbol} = [:include] # functions to treat like include
end

mutable struct TestsetExpr
Expand Down Expand Up @@ -121,12 +122,14 @@ function extract_testsets(dest)
end

# replace unqualified `@testset` by TestsetExpr
function replace_ts(source, mod, x::Expr, parent; static_include::Bool)
function replace_ts(source, mod, x::Expr, parent; static_include::Bool,
include_functions::Vector{Symbol})
if x.head === :macrocall
name = x.args[1]
if name === Symbol("@testset")
@assert x.args[2] isa LineNumberNode
ts, hasbroken = parse_ts(x.args[2], mod, Tuple(x.args[3:end]), parent;
include_functions=include_functions,
static_include=static_include)
ts !== invalid && parent !== nothing && push!(parent.children, ts)
ts, false # hasbroken counts only "proper" @test_broken, not recursive ones
Expand All @@ -136,16 +139,17 @@ function replace_ts(source, mod, x::Expr, parent; static_include::Bool)
# `@test` is generally called a lot, so it's probably worth it to skip
# the containment test in this case
x = macroexpand(mod, x, recursive=false)
replace_ts(source, mod, x, parent; static_include=static_include)
replace_ts(source, mod, x, parent; static_include=static_include,
include_functions=include_functions)
else
@goto default
end
elseif x.head == :call && x.args[1] == :include
elseif x.head == :call && x.args[1] include_functions
path = x.args[end]
sourcepath = dirname(string(source.file))
x.args[end] = path isa AbstractString ?
joinpath(sourcepath, path) :
:(joinpath($sourcepath, $path))
joinpath(sourcepath, path) :
:(joinpath($sourcepath, $path))
if static_include
news = InlineTest.get_tests(mod).news
newslen = length(news)
Expand All @@ -166,26 +170,30 @@ function replace_ts(source, mod, x::Expr, parent; static_include::Bool)
# below; it's currently not very important
tsi.source, tsi.ts...)
for tsi in newstmp)...)
replace_ts(source, mod, included_ts, parent; static_include=static_include)
replace_ts(source, mod, included_ts, parent;
static_include=static_include, include_functions=include_functions)
else
nothing, false
end
else
x, false
end
else @label default
body_br = map(z -> replace_ts(source, mod, z, parent; static_include=static_include),
body_br = map(z -> replace_ts(source, mod, z, parent; static_include=static_include,
include_functions=include_functions),
x.args)
filter!(x -> first(x) !== invalid, body_br)
Expr(x.head, first.(body_br)...), any(last.(body_br))
end
end

replace_ts(source, mod, x, _1; static_include::Bool) = x, false
replace_ts(source, mod, x, _1; static_include::Bool,
include_functions) = x, false

# create a TestsetExpr from @testset's args
function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothing;
static_include::Bool=false)
static_include::Bool=false, include_functions::Vector{Symbol}=[:include])

function tserror(msg)
@error msg _file=String(source.file) _line=source.line _module=mod
invalid, false
Expand All @@ -195,14 +203,15 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi
return tserror("expected begin/end block or for loop as argument to @testset")

local desc
options = Options()
options = Options(include_functions=include_functions)
marks = Marks()
if parent !== nothing
append!(marks.hard, parent.marks.hard) # copy! not available in Julia 1.0
options.static_include = parent.options.static_include
# if static_include was set in parent, it should have been forwarded also
# through the parse_ts/replace_ts call chains:
@assert static_include == parent.options.static_include
@assert include_functions === parent.options.include_functions
end
for arg in args[1:end-1]
if arg isa String || Meta.isexpr(arg, :string)
Expand All @@ -211,10 +220,22 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi
# TODO: support non-literal symbols?
push!(marks.hard, arg.value)
elseif Meta.isexpr(arg, :(=))
arg.args[1] in fieldnames(Options) ||
return tserror("unsupported @testset option")
# TODO: make that work with non-literals:
setfield!(options, arg.args[1], arg.args[2])
optname = arg.args[1]
optname in fieldnames(Options) ||
return tserror("unsupported @testset option: $optname")
if optname == :include_functions
@assert Meta.isexpr(arg.args[2], :vect)
if parent !== nothing
options.include_functions = Symbol[] # make it non-shared
end
for ifn in arg.args[2].args
@assert ifn isa QuoteNode
push!(options.include_functions, ifn.value)
end
else
# TODO: make that work with non-literals:
setfield!(options, optname, arg.args[2])
end
else
return tserror("unsupported @testset")
end
Expand Down Expand Up @@ -253,7 +274,8 @@ function parse_ts(source::LineNumberNode, mod::Module, args::Tuple, parent=nothi

ts = TestsetExpr(source, mod, desc, options, marks, loops, parent)
ts.body, ts.hasbroken = replace_ts(source, mod, tsbody, ts;
static_include=options.static_include)
static_include=options.static_include,
include_functions=options.include_functions)
ts, false # hasbroken counts only "proper" @test_broken, not recursive ones
end

Expand Down
54 changes: 41 additions & 13 deletions src/hijack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ const loaded_testmodules = Dict{Module,Vector{Module}}()

"""
ReTest.hijack(source, [modname];
parentmodule::Module=Main, lazy=false, [include::Symbol],
[revise::Bool])
parentmodule::Module=Main, lazy=false, [revise::Bool],
[include::Symbol], include_functions=[:include])
Given test files defined in `source` using the `Test` package, try to load
them by replacing `Test` with `ReTest`, wrapping them in a module `modname`
Expand Down Expand Up @@ -183,6 +183,15 @@ so the best that can be done here is to "outline" such nested testsets; with
`include=:static`, the subdmodules will get defined after `hijack` has
returned (on the first call to `retest` thereafter), so won't be "processed".
#### `include_functions` keyword
When the `include=:static` keyword argument is passed, it's possible to
tell `hijack` to apply the same treatment to other functions than
`include`, by passing a list a symbols to `include_functions`.
For example, if you defined a custom function `custom_include(x)`
which itself calls out to `include`, you can pass
`include_functions=[:custom_include]` to `hijack`.
#### `revise` keyword
The `revise` keyword specifies whether `Revise` should be used to track
Expand All @@ -201,7 +210,8 @@ function hijack end

function hijack(path::AbstractString, modname=nothing; parentmodule::Module=Main,
lazy=false, revise::Maybe{Bool}=nothing,
include::Maybe{Symbol}=nothing, testset::Bool=false)
include::Maybe{Symbol}=nothing, testset::Bool=false,
include_functions=[:include])

# do first, to error early if necessary
Revise = get_revise(revise)
Expand All @@ -213,6 +223,7 @@ function hijack(path::AbstractString, modname=nothing; parentmodule::Module=Main

newmod = @eval parentmodule module $modname end
populate_mod!(newmod, path; lazy=lazy, include=setinclude(include, testset),
include_functions=include_functions,
Revise=Revise)
newmod
end
Expand All @@ -234,12 +245,14 @@ const root_module = Ref{Symbol}()

__init__() = root_module[] = gensym("MODULE")

function populate_mod!(mod::Module, path; lazy, Revise, include::Maybe{Symbol}=nothing)
function populate_mod!(mod::Module, path; lazy, Revise, include::Maybe{Symbol}=nothing,
include_functions)
lazy (true, false, :brutal) ||
throw(ArgumentError("the `lazy` keyword must be `true`, `false` or `:brutal`"))

files = Revise === nothing ? nothing : Dict(path => mod)
substitute!(x) = substitute_retest!(x, lazy, include, files)
substitute!(x) = substitute_retest!(x, lazy, include, files;
include_functions=include_functions)

@eval mod begin
using ReTest # for files which don't have `using Test`
Expand Down Expand Up @@ -268,7 +281,8 @@ end

function hijack(packagemod::Module, modname=nothing; parentmodule::Module=Main,
lazy=false, revise::Maybe{Bool}=nothing,
include::Maybe{Symbol}=nothing, testset::Bool=false)
include::Maybe{Symbol}=nothing, testset::Bool=false,
include_functions=[:include])
packagepath = pathof(packagemod)
packagepath === nothing && packagemod !== Base &&
throw(ArgumentError("$packagemod is not a package"))
Expand All @@ -286,13 +300,15 @@ function hijack(packagemod::Module, modname=nothing; parentmodule::Module=Main,
else
path = joinpath(dirname(dirname(packagepath)), "test", "runtests.jl")
hijack(path, modname, parentmodule=parentmodule,
lazy=lazy, testset=testset, include=include, revise=revise)
lazy=lazy, testset=testset, include=include, revise=revise,
include_functions=include_functions)
end
end

function substitute_retest!(ex, lazy, include_::Maybe{Symbol}, files=nothing;
ishijack::Bool=true)
substitute!(x) = substitute_retest!(x, lazy, include_, files, ishijack=ishijack)
ishijack::Bool=true, include_functions::Vector{Symbol}=[:include])
substitute!(x) = substitute_retest!(x, lazy, include_, files, ishijack=ishijack,
include_functions=include_functions)

if Meta.isexpr(ex, :using)
ishijack || return ex
Expand Down Expand Up @@ -349,16 +365,28 @@ function substitute_retest!(ex, lazy, include_::Maybe{Symbol}, files=nothing;
if body.head == :for
body = body.args[end]
end
includes = splice!(body.args, findall(body.args) do x
Meta.isexpr(x, :call) && x.args[1] == :include
end)
includes = splice!(body.args,
findall(body.args) do x
Meta.isexpr(x, :call) && x.args[1] include_functions
end)
map!(substitute!, includes, includes)
ex.head = :block
newts = Expr(:macrocall, ex.args...)
push!(empty!(ex.args), newts, includes...)
else # :static
elseif include_ == :static
pos = ex.args[2] isa LineNumberNode ? 3 : 2
insert!(ex.args, pos, :(static_include=true))
if !isempty(include_functions)
# enabled only for :static currently, would need to change ex.args
# to newts.args in order to make it work for :outline
ifn_expr = Expr(:vect)
for ifn in include_functions
push!(ifn_expr.args, QuoteNode(ifn))
end
insert!(ex.args, length(ex.args)-1, :(include_functions=$ifn_expr))
end
else
@assert false
end
end
elseif ex isa Expr && ex.head (:block, :let, :for, :while, :if, :try)
Expand Down
4 changes: 3 additions & 1 deletion test/Hijack/test/include_static.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# include = :static
using Hijack, Test

custom_include_function(f) = include(f)

@testset "include_static" begin
@test true
push!(Hijack.RUN, 1)
include("include_static_included1.jl")
custom_include_function("include_static_included1.jl")
end
5 changes: 3 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ end
@chapter TestsetErrors begin
@test_logs (
:error, "expected begin/end block or for loop as argument to @testset") (
:error, "unsupported @testset option") (
:error, "unsupported @testset option: notexistingoption") (
:error, "unsupported @testset" ) (
:error, "expected begin/end block or for loop as argument to @testset") (
:error, "expected begin/end block or for loop as argument to @testset"
Expand Down Expand Up @@ -1862,7 +1862,8 @@ end
empty!(Hijack.RUN)
@test_throws ErrorException ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static, testset=true)
@test_throws ErrorException ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:notvalid)
ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static)
ReTest.hijack("./Hijack/test/include_static.jl", :HijackInclude, include=:static,
include_functions=[:include, :custom_include_function])
check(HijackInclude, dry=true, verbose=9, [], output="""
1| include_static
2| include_static_included1 1
Expand Down

0 comments on commit 4ca5f84

Please sign in to comment.