Skip to content

Commit

Permalink
Implemented findall(in(interval), x::AbstractRange), fixes #52 (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
jagot authored May 19, 2020
1 parent 9c7cc1c commit 162bd0d
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ julia = "0.7, 1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

[targets]
test = ["Test"]
test = ["Test", "Random", "OffsetArrays"]
1 change: 1 addition & 0 deletions src/IntervalSets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,6 @@ convert(::Type{ClosedInterval}, x::Number) = x..x
convert(::Type{ClosedInterval{T}}, x::Number) where T =
convert(AbstractInterval{T}, convert(AbstractInterval, x))

include("findall.jl")

end # module
85 changes: 85 additions & 0 deletions src/findall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
findall(in(interval), x::AbstractRange)
Return all indices `i` for which `x[i] ∈ interval`, specialized for
the case where `x` is a range, which enables constant-time complexity.
# Examples
```jldoctest
julia> x = range(0,stop=3,length=10)
0.0:0.3333333333333333:3.0
julia> collect(x)'
1×10 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
0.0 0.333333 0.666667 1.0 1.33333 1.66667 2.0 2.33333 2.66667 3.0
julia> findall(in(1..6), x)
4:10
```
It also works for decreasing ranges:
```jldoctest
julia> y = 8:-0.5:0
8.0:-0.5:0.0
julia> collect(y)'
1×17 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}:
8.0 7.5 7.0 6.5 6.0 5.5 5.0 4.5 4.0 3.5 3.0 2.5 2.0 1.5 1.0 0.5 0.0
julia> findall(in(1..6), y)
5:15
julia> findall(in(Interval{:open,:closed}(1,6)), y) # (1,6], does not include 1
5:14
```
"""
function Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractRange) where {L,R,T}
isempty(x) && return 1:0

interval = interval_d.x
il, ir = firstindex(x), lastindex(x)
δx = step(x)
a,b = if δx < 0
rev = findall(in(interval), reverse(x))
isempty(rev) && return rev

a = (il+ir)-last(rev)
b = (il+ir)-first(rev)

a,b
else
lx, rx = first(x), last(x)
l = max(leftendpoint(interval), lx-1)
r = min(rightendpoint(interval), rx+1)

(l > rx || r < lx) && return 1:0

a = il + max(0, round(Int, cld(l-lx, δx)))
a += (a ir && (x[a] == l && L == :open || x[a] < l))

b = min(ir, round(Int, cld(r-lx, δx)) + il)
b -= (b il && (x[b] == r && R == :open || x[b] > r))

a,b
end
# Reversing a range could change sign of values close to zero (cf
# sign of the smallest element in x and reverse(x), where x =
# range(BigFloat(-0.5),stop=BigFloat(1.0),length=10)), or more
# generally push elements in or out of the interval (as can cld),
# so we need to check once again.
a += +(a < ir && x[a] interval) - (il < a && x[a-1] interval)
b += -(il < b && x[b] interval) + (b < ir && x[b+1] interval)

a:b
end

# We overload Base._findin to avoid an ambiguity that arises with
# Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractArray)
function Base._findin(a::Union{AbstractArray, Tuple}, b::Interval)
ind = Vector{eltype(keys(a))}()
@inbounds for (i,ai) in pairs(a)
ai in b && push!(ind, i)
end
ind
end
127 changes: 127 additions & 0 deletions test/findall.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using OffsetArrays

# Helper function to test that findall(in(interval), x) works. By
# default, a reference is generated using the general algorithm,
# linear in complexity, by generating a vector with the same contents
# as x.
function assert_in_interval(x, interval,
expected=findall(v -> v interval, x))

result = :(findall(in($interval), $x))
expr = :($result == $expected || isempty($result) && isempty($expected))
if !(@eval $expr)
println("Looking for elements of $x$interval, got $(@eval $result), expected $expected")
length(x) < 30 && println(" x = ", collect(pairs(x)), "\n")
end
@eval @test $expr
end

@testset "Interval coverage" begin
@testset "Basic tests" begin
let x = range(0, stop=1, length=21)
Random.seed!(321)
@testset "$kind" for (kind,end_points) in [
("Two intervals", [(0.0, 0.5), (0.25,0.5)]),
("Three intervals", [(0, 1/3), (1/3, 2/3), (2/3, 1)]),
("Random intervals", [minmax(rand(),rand()) for i = 1:2]),
("Interval containing one point", [(0.4619303378979984,0.5450937144417902)]),
("Interval containing no points", [(0.9072957410215778,0.9082803807133988)])
]
@testset "L=$L" for L=[:closed,:open]
@testset "R=$R" for R=[:closed,:open]
for (a,b) in end_points
interval = Interval{L,R}(a, b)
@testset "Reversed: $reversed" for reversed in [false, true]
assert_in_interval(reversed ? reverse(x) : x, interval)
end
end
end
end
end

@testset "Open interval" begin
assert_in_interval(x, OpenInterval(0.2,0.4), 6:8)
end
end
end

@testset "Partially covered intervals" begin
@testset "$T" for T in (Float32,Float64,BigFloat)
@testset "$name, x = $x" for (name,x) in [
("Outside left",range(T(-1),stop=T(-0.5),length=10)),
("Touching left",range(T(-1),stop=T(0),length=10)),
("Touching left-ϵ",range(T(-1),stop=T(0)-eps(T),length=10)),
("Touching left+ϵ",range(T(-1),stop=T(0)+eps(T),length=10)),

("Outside right",range(T(1.5),stop=T(2),length=10)),
("Touching right",range(T(1),stop=T(2),length=10)),
("Touching right-ϵ",range(T(1)-eps(T),stop=T(2),length=10)),
("Touching right+ϵ",range(T(1)+eps(T),stop=T(2),length=10)),

("Other right",range(T(0.5),stop=T(1),length=10)),
("Other right-ϵ",range(T(0.5)-eps(T(0.5)),stop=T(1),length=10)),
("Other right+ϵ",range(T(0.5)+eps(T(0.5)),stop=T(1),length=10)),

("Complete", range(T(0),stop=T(1),length=10)),
("Complete-ϵ", range(eps(T),stop=T(1)-eps(T),length=10)),
("Complete+ϵ", range(-eps(T),stop=T(1)+eps(T),length=10)),

("Left partial", range(T(-0.5),stop=T(0.6),length=10)),
("Left", range(T(-0.5),stop=T(1.0),length=10)),
("Right partial", range(T(0.5),stop=T(1.6),length=10)),
("Right", range(T(0),stop=T(1.6),length=10))]
@testset "L=$L" for L=[:closed,:open]
@testset "R=$R" for R=[:closed,:open]
@testset "Reversed: $reversed" for reversed in [false, true]
for (a,b) in [(T(0.0),T(0.5)),(T(0.5),T(1.0))]
interval = Interval{L,R}(a, b)
assert_in_interval(reversed ? reverse(x) : x, interval)
end
end
end
end
end
end
end

@testset "Large intervals" begin
@test findall(in(4..Inf), 2:2:10) == 2:5
@test findall(in(4..1e20), 2:2:10) == 2:5
@test isempty(findall(in(-Inf..(-1e20)), 2:2:10))
end

@testset "Reverse intervals" begin
for x in [1:10, 1:3:10, 2:3:11, -1:9, -2:0.5:5]
for lo in -3:4, hi in 5:13
for L in [:closed, :open], R in [:closed, :open]
interval = Interval{L,R}(lo,hi)
assert_in_interval(x, interval)
assert_in_interval(reverse(x), interval)
end
end
end
end

@testset "Arrays" begin
@test findall(in(1..6), collect(0:7)) == 2:7
@test findall(in(1..6), reshape(1:16, 4, 4)) ==
vcat([CartesianIndex(i,1) for i = 1:4], CartesianIndex(1,2), CartesianIndex(2,2))
end

@testset "Empty ranges and intervals" begin
# Range empty
@test isempty(findall(in(1..6), 1:0))
# Interval empty
@test isempty(findall(in(Interval{:closed,:open}(1.0..1.0)),
0.0:0.02040816326530612:1.0))
end

@testset "Offset arrays" begin
for (x,interval) in [(OffsetArray(ones(10), -5), -1..1),
(OffsetArray(1:5, -3), 2..4),
(OffsetArray(5:-1:1, -5), 2..4)]
assert_in_interval(x, interval)
assert_in_interval(reverse(x), interval)
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test
using Dates
using Statistics
import Statistics: mean
using Random

import IntervalSets: Domain, endpoints, closedendpoints, TypedEndpointsInterval

Expand Down Expand Up @@ -706,4 +707,6 @@ struct IncompleteInterval <: AbstractInterval{Int} end
@test_throws ErrorException endpoints(I)
@test_throws ErrorException closedendpoints(I)
end

include("findall.jl")
end

0 comments on commit 162bd0d

Please sign in to comment.