Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance Piracy detection by considering more methods #156

Merged
merged 4 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 53 additions & 44 deletions src/piracy.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,58 @@
module Piracy

using Test: @test, @test_broken
using ..Aqua: walkmodules

const DEFAULT_PKGS = (Base.PkgId(Base), Base.PkgId(Core))

function all_methods!(
mod::Module,
done_callables::Base.IdSet{Any}, # cached to prevent duplicates
result::Vector{Method},
filter_default::Bool,
)::Vector{Method}
for name in names(mod; all = true, imported = true)
# names can list undefined symbols which cannot be eval'd
isdefined(mod, name) || continue

# Skip closures
startswith(String(name), "#") && continue
val = getfield(mod, name)

if !in(val, done_callables)
# In old versions of Julia, Vararg errors when methods is called on it
val === Vararg && continue
for method in methods(val)
# Default filtering removes all methods defined in DEFAULT_PKGs,
# since these may pirate each other.
if !(filter_default && in(Base.PkgId(method.module), DEFAULT_PKGS))
push!(result, method)
end
end
push!(done_callables, val)
if VERSION >= v"1.6-"
using Test: is_in_mods
else
function is_in_mods(m::Module, recursive::Bool, mods)
while true
m in mods && return true
recursive || return false
p = parentmodule(m)
p === m && return false
m = p
end
end
result
end

function all_methods(mod::Module; filter_default::Bool = true)
result = Method[]
done_callables = Base.IdSet()
walkmodules(mod) do mod
all_methods!(mod, done_callables, result, filter_default)
# based on Test/Test.jl#detect_ambiguities
# https://github.com/JuliaLang/julia/blob/v1.9.1/stdlib/Test/src/Test.jl#L1838-L1896
function all_methods(mods::Module...; skip_deprecated::Bool = true)
meths = Method[]
mods = collect(mods)::Vector{Module}

function examine(mt::Core.MethodTable)
examine(Base.MethodList(mt))
end
return result
function examine(ml::Base.MethodList)
for m in ml
is_in_mods(m.module, true, mods) || continue
push!(meths, m)
end
end

work = Base.loaded_modules_array()
filter!(mod -> mod === parentmodule(mod), work) # some items in loaded_modules_array are not top modules (really just Base)
while !isempty(work)
mod = pop!(work)
for name in names(mod; all = true)
(skip_deprecated && Base.isdeprecated(mod, name)) && continue
isdefined(mod, name) || continue
f = Base.unwrap_unionall(getfield(mod, name))
if isa(f, Module) && f !== mod && parentmodule(f) === mod && nameof(f) === name
push!(work, f)
elseif isa(f, DataType) &&
isdefined(f.name, :mt) &&
parentmodule(f) === mod &&
nameof(f) === name &&
f.name.mt !== Symbol.name.mt &&
f.name.mt !== DataType.name.mt
examine(f.name.mt)
end
end
end
examine(Symbol.name.mt)
examine(DataType.name.mt)
return meths
end

##################################
Expand Down Expand Up @@ -141,7 +152,7 @@ function is_foreign_method(@nospecialize(T::DataType), pkg::Base.PkgId; treat_as

# fallback to general code
return !(T in treat_as_own) &&
!(T <: Function && T.instance in treat_as_own) &&
!(T <: Function && isdefined(T, :instance) && T.instance in treat_as_own) &&
is_foreign(T, pkg; treat_as_own = treat_as_own)
end

Expand All @@ -162,12 +173,9 @@ function is_pirate(meth::Method; treat_as_own = Union{Function,Type}[])
)
end

hunt(mod::Module; from::Module = mod, kwargs...) =
hunt(Base.PkgId(mod); from = from, kwargs...)

function hunt(pkg::Base.PkgId; from::Module, kwargs...)
filter(all_methods(from)) do method
Base.PkgId(method.module) === pkg && is_pirate(method; kwargs...)
function hunt(mod::Module; skip_deprecated::Bool = true, kwargs...)
filter(all_methods(mod; skip_deprecated = skip_deprecated)) do method
method.module === mod && is_pirate(method; kwargs...)
end
end

Expand All @@ -182,6 +190,7 @@ See [Julia documentation](https://docs.julialang.org/en/v1/manual/style-guide/#A
# Keyword Arguments
- `broken::Bool = false`: If true, it uses `@test_broken` instead of
`@test`.
- `skip_deprecated::Bool = true`: If true, it does not check deprecated methods.
- `treat_as_own = Union{Function, Type}[]`: The types in this container
are considered to be "owned" by the module `m`. This is useful for
testing packages that deliberately commit some type piracy, e.g. modules
Expand Down
4 changes: 4 additions & 0 deletions test/pkgs/PiracyForeignProject/src/PiracyForeignProject.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,8 @@ module PiracyForeignProject
struct ForeignType end
struct ForeignParameterizedType{T} end

struct ForeignNonSingletonType
x::Int
end

end
83 changes: 52 additions & 31 deletions test/test_piracy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ push!(LOAD_PATH, joinpath(@__DIR__, "pkgs", "PiracyForeignProject"))

baremodule PiracyModule

using PiracyForeignProject: ForeignType, ForeignParameterizedType
using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType

using Base:
Base,
Expand Down Expand Up @@ -44,6 +44,8 @@ export MyUnion
Base.findfirst(::Set{Vector{Char}}, ::Int) = 1
Base.findfirst(::Union{Foo,Bar{Set{Unsigned}},UInt}, ::Tuple{Vararg{String}}) = 1
Base.findfirst(::AbstractChar, ::Set{T}) where {Int <: T <: Integer} = 1
(::ForeignType)(x::Int8) = x + 1
(::ForeignNonSingletonType)(x::Int8) = x + 1

# Piracy, but not for `ForeignType in treat_as_own`
Base.findmax(::ForeignType, x::Int) = x + 1
Expand All @@ -55,29 +57,27 @@ Base.findmin(::ForeignParameterizedType{Int}, x::Int) = x + 1
Base.findmin(::Set{Vector{ForeignParameterizedType{Int}}}, x::Int) = x + 1
Base.findmin(::Union{Foo,ForeignParameterizedType{Int}}, x::Int) = x + 1

# Assign them names in this module so they can be found by all_methods
a = Base.findfirst
b = Base.findlast
c = Base.findmax
d = Base.findmin
end # PiracyModule

using Aqua: Piracy
using PiracyForeignProject: ForeignType, ForeignParameterizedType
using PiracyForeignProject: ForeignType, ForeignParameterizedType, ForeignNonSingletonType

# Get all methods - test length
meths = filter(Piracy.all_methods(PiracyModule)) do m
m.module == PiracyModule
end

# 2 Foo constructors
# 2 from f
# 1 from MyUnion
# 6 from findlast
# 3 from findfirst
# 3 from findmax
# 3 from findmin
@test length(meths) == 2 + 2 + 1 + 6 + 3 + 3 + 3
@test length(meths) ==
2 + # Foo constructors
1 + # Bar constructor
2 + # f
1 + # MyUnion
6 + # findlast
3 + # findfirst
1 + # ForeignType callable
1 + # ForeignNonSingletonType callable
3 + # findmax
3 # findmin

# Test what is foreign
BasePkg = Base.PkgId(Base)
Expand All @@ -90,49 +90,70 @@ ThisPkg = Base.PkgId(PiracyModule)
@test !Piracy.is_foreign(Set{Int}, CorePkg; treat_as_own = [])

# Test what is pirate
pirates = filter(m -> Piracy.is_pirate(m), meths)
@test length(pirates) == 3 + 3 + 3
pirates = Piracy.hunt(PiracyModule)
@test length(pirates) ==
3 + # findfirst
3 + # findmax
3 + # findmin
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmax, :findmin]
m.name in [:findfirst, :findmax, :findmin, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignType]), meths)
@test length(pirates) == 3 + 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignType])
@test length(pirates) ==
3 + # findfirst
3 + # findmin
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmin]
m.name in [:findfirst, :findmin, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignParameterizedType])
pirates = filter(m -> Piracy.is_pirate(m; treat_as_own = [ForeignParameterizedType]), meths)
@test length(pirates) == 3 + 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [ForeignParameterizedType])
@test length(pirates) ==
3 + # findfirst
3 + # findmax
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst, :findmax]
m.name in [:findfirst, :findmax, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[ForeignType, ForeignParameterizedType])
pirates = filter(
m -> Piracy.is_pirate(m; treat_as_own = [ForeignType, ForeignParameterizedType]),
meths,
)
@test length(pirates) == 3
@test length(pirates) ==
3 + # findfirst
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findfirst]
m.name in [:findfirst, :ForeignNonSingletonType]
end

# Test what is pirate (with treat_as_own=[Base.findfirst, Base.findmax])
pirates =
filter(m -> Piracy.is_pirate(m; treat_as_own = [Base.findfirst, Base.findmax]), meths)
@test length(pirates) == 3
pirates = Piracy.hunt(PiracyModule, treat_as_own = [Base.findfirst, Base.findmax])
@test length(pirates) ==
3 + # findmin
1 + # ForeignType callable
1 # ForeignNonSingletonType callable
@test all(pirates) do m
m.name in [:findmin]
m.name in [:findmin, :ForeignType, :ForeignNonSingletonType]
end

# Test what is pirate (excluding a cover of everything)
pirates = filter(
m -> Piracy.is_pirate(
m;
treat_as_own = [ForeignType, ForeignParameterizedType, Base.findfirst],
treat_as_own = [
ForeignType,
ForeignParameterizedType,
ForeignNonSingletonType,
Base.findfirst,
],
),
meths,
)
Expand Down