From 162bd0d9ae952347980a3e2f7e6561e5170806bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefanos=20Carlstr=C3=B6m?= Date: Tue, 19 May 2020 16:18:53 +0200 Subject: [PATCH] Implemented findall(in(interval), x::AbstractRange), fixes #52 (#63) --- Project.toml | 4 +- src/IntervalSets.jl | 1 + src/findall.jl | 85 +++++++++++++++++++++++++++++ test/findall.jl | 127 ++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 ++ 5 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 src/findall.jl create mode 100644 test/findall.jl diff --git a/Project.toml b/Project.toml index 326b16a..c0c5f5a 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,8 @@ julia = "0.7, 1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" [targets] -test = ["Test"] +test = ["Test", "Random", "OffsetArrays"] diff --git a/src/IntervalSets.jl b/src/IntervalSets.jl index e61b39f..e5d9298 100644 --- a/src/IntervalSets.jl +++ b/src/IntervalSets.jl @@ -272,5 +272,6 @@ convert(::Type{ClosedInterval}, x::Number) = x..x convert(::Type{ClosedInterval{T}}, x::Number) where T = convert(AbstractInterval{T}, convert(AbstractInterval, x)) +include("findall.jl") end # module diff --git a/src/findall.jl b/src/findall.jl new file mode 100644 index 0000000..9bc05c0 --- /dev/null +++ b/src/findall.jl @@ -0,0 +1,85 @@ +""" + findall(in(interval), x::AbstractRange) + +Return all indices `i` for which `x[i] ∈ interval`, specialized for +the case where `x` is a range, which enables constant-time complexity. + +# Examples + +```jldoctest +julia> x = range(0,stop=3,length=10) +0.0:0.3333333333333333:3.0 + +julia> collect(x)' +1×10 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}: + 0.0 0.333333 0.666667 1.0 1.33333 1.66667 2.0 2.33333 2.66667 3.0 + +julia> findall(in(1..6), x) +4:10 +``` + +It also works for decreasing ranges: +```jldoctest +julia> y = 8:-0.5:0 +8.0:-0.5:0.0 + +julia> collect(y)' +1×17 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}: + 8.0 7.5 7.0 6.5 6.0 5.5 5.0 4.5 4.0 3.5 3.0 2.5 2.0 1.5 1.0 0.5 0.0 + +julia> findall(in(1..6), y) +5:15 + +julia> findall(in(Interval{:open,:closed}(1,6)), y) # (1,6], does not include 1 +5:14 +``` +""" +function Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractRange) where {L,R,T} + isempty(x) && return 1:0 + + interval = interval_d.x + il, ir = firstindex(x), lastindex(x) + δx = step(x) + a,b = if δx < 0 + rev = findall(in(interval), reverse(x)) + isempty(rev) && return rev + + a = (il+ir)-last(rev) + b = (il+ir)-first(rev) + + a,b + else + lx, rx = first(x), last(x) + l = max(leftendpoint(interval), lx-1) + r = min(rightendpoint(interval), rx+1) + + (l > rx || r < lx) && return 1:0 + + a = il + max(0, round(Int, cld(l-lx, δx))) + a += (a ≤ ir && (x[a] == l && L == :open || x[a] < l)) + + b = min(ir, round(Int, cld(r-lx, δx)) + il) + b -= (b ≥ il && (x[b] == r && R == :open || x[b] > r)) + + a,b + end + # Reversing a range could change sign of values close to zero (cf + # sign of the smallest element in x and reverse(x), where x = + # range(BigFloat(-0.5),stop=BigFloat(1.0),length=10)), or more + # generally push elements in or out of the interval (as can cld), + # so we need to check once again. + a += +(a < ir && x[a] ∉ interval) - (il < a && x[a-1] ∈ interval) + b += -(il < b && x[b] ∉ interval) + (b < ir && x[b+1] ∈ interval) + + a:b +end + +# We overload Base._findin to avoid an ambiguity that arises with +# Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractArray) +function Base._findin(a::Union{AbstractArray, Tuple}, b::Interval) + ind = Vector{eltype(keys(a))}() + @inbounds for (i,ai) in pairs(a) + ai in b && push!(ind, i) + end + ind +end diff --git a/test/findall.jl b/test/findall.jl new file mode 100644 index 0000000..638d2f8 --- /dev/null +++ b/test/findall.jl @@ -0,0 +1,127 @@ +using OffsetArrays + +# Helper function to test that findall(in(interval), x) works. By +# default, a reference is generated using the general algorithm, +# linear in complexity, by generating a vector with the same contents +# as x. +function assert_in_interval(x, interval, + expected=findall(v -> v ∈ interval, x)) + + result = :(findall(in($interval), $x)) + expr = :($result == $expected || isempty($result) && isempty($expected)) + if !(@eval $expr) + println("Looking for elements of $x ∈ $interval, got $(@eval $result), expected $expected") + length(x) < 30 && println(" x = ", collect(pairs(x)), "\n") + end + @eval @test $expr +end + +@testset "Interval coverage" begin + @testset "Basic tests" begin + let x = range(0, stop=1, length=21) + Random.seed!(321) + @testset "$kind" for (kind,end_points) in [ + ("Two intervals", [(0.0, 0.5), (0.25,0.5)]), + ("Three intervals", [(0, 1/3), (1/3, 2/3), (2/3, 1)]), + ("Random intervals", [minmax(rand(),rand()) for i = 1:2]), + ("Interval containing one point", [(0.4619303378979984,0.5450937144417902)]), + ("Interval containing no points", [(0.9072957410215778,0.9082803807133988)]) + ] + @testset "L=$L" for L=[:closed,:open] + @testset "R=$R" for R=[:closed,:open] + for (a,b) in end_points + interval = Interval{L,R}(a, b) + @testset "Reversed: $reversed" for reversed in [false, true] + assert_in_interval(reversed ? reverse(x) : x, interval) + end + end + end + end + end + + @testset "Open interval" begin + assert_in_interval(x, OpenInterval(0.2,0.4), 6:8) + end + end + end + + @testset "Partially covered intervals" begin + @testset "$T" for T in (Float32,Float64,BigFloat) + @testset "$name, x = $x" for (name,x) in [ + ("Outside left",range(T(-1),stop=T(-0.5),length=10)), + ("Touching left",range(T(-1),stop=T(0),length=10)), + ("Touching left-ϵ",range(T(-1),stop=T(0)-eps(T),length=10)), + ("Touching left+ϵ",range(T(-1),stop=T(0)+eps(T),length=10)), + + ("Outside right",range(T(1.5),stop=T(2),length=10)), + ("Touching right",range(T(1),stop=T(2),length=10)), + ("Touching right-ϵ",range(T(1)-eps(T),stop=T(2),length=10)), + ("Touching right+ϵ",range(T(1)+eps(T),stop=T(2),length=10)), + + ("Other right",range(T(0.5),stop=T(1),length=10)), + ("Other right-ϵ",range(T(0.5)-eps(T(0.5)),stop=T(1),length=10)), + ("Other right+ϵ",range(T(0.5)+eps(T(0.5)),stop=T(1),length=10)), + + ("Complete", range(T(0),stop=T(1),length=10)), + ("Complete-ϵ", range(eps(T),stop=T(1)-eps(T),length=10)), + ("Complete+ϵ", range(-eps(T),stop=T(1)+eps(T),length=10)), + + ("Left partial", range(T(-0.5),stop=T(0.6),length=10)), + ("Left", range(T(-0.5),stop=T(1.0),length=10)), + ("Right partial", range(T(0.5),stop=T(1.6),length=10)), + ("Right", range(T(0),stop=T(1.6),length=10))] + @testset "L=$L" for L=[:closed,:open] + @testset "R=$R" for R=[:closed,:open] + @testset "Reversed: $reversed" for reversed in [false, true] + for (a,b) in [(T(0.0),T(0.5)),(T(0.5),T(1.0))] + interval = Interval{L,R}(a, b) + assert_in_interval(reversed ? reverse(x) : x, interval) + end + end + end + end + end + end + end + + @testset "Large intervals" begin + @test findall(in(4..Inf), 2:2:10) == 2:5 + @test findall(in(4..1e20), 2:2:10) == 2:5 + @test isempty(findall(in(-Inf..(-1e20)), 2:2:10)) + end + + @testset "Reverse intervals" begin + for x in [1:10, 1:3:10, 2:3:11, -1:9, -2:0.5:5] + for lo in -3:4, hi in 5:13 + for L in [:closed, :open], R in [:closed, :open] + interval = Interval{L,R}(lo,hi) + assert_in_interval(x, interval) + assert_in_interval(reverse(x), interval) + end + end + end + end + + @testset "Arrays" begin + @test findall(in(1..6), collect(0:7)) == 2:7 + @test findall(in(1..6), reshape(1:16, 4, 4)) == + vcat([CartesianIndex(i,1) for i = 1:4], CartesianIndex(1,2), CartesianIndex(2,2)) + end + + @testset "Empty ranges and intervals" begin + # Range empty + @test isempty(findall(in(1..6), 1:0)) + # Interval empty + @test isempty(findall(in(Interval{:closed,:open}(1.0..1.0)), + 0.0:0.02040816326530612:1.0)) + end + + @testset "Offset arrays" begin + for (x,interval) in [(OffsetArray(ones(10), -5), -1..1), + (OffsetArray(1:5, -3), 2..4), + (OffsetArray(5:-1:1, -5), 2..4)] + assert_in_interval(x, interval) + assert_in_interval(reverse(x), interval) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e4df502..46d97cb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test using Dates using Statistics import Statistics: mean +using Random import IntervalSets: Domain, endpoints, closedendpoints, TypedEndpointsInterval @@ -706,4 +707,6 @@ struct IncompleteInterval <: AbstractInterval{Int} end @test_throws ErrorException endpoints(I) @test_throws ErrorException closedendpoints(I) end + + include("findall.jl") end