-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
- Loading branch information
Showing
5 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
findall(in(interval), x::AbstractRange) | ||
Return all indices `i` for which `x[i] ∈ interval`, specialized for | ||
the case where `x` is a range, which enables constant-time complexity. | ||
# Examples | ||
```jldoctest | ||
julia> x = range(0,stop=3,length=10) | ||
0.0:0.3333333333333333:3.0 | ||
julia> collect(x)' | ||
1×10 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}: | ||
0.0 0.333333 0.666667 1.0 1.33333 1.66667 2.0 2.33333 2.66667 3.0 | ||
julia> findall(in(1..6), x) | ||
4:10 | ||
``` | ||
It also works for decreasing ranges: | ||
```jldoctest | ||
julia> y = 8:-0.5:0 | ||
8.0:-0.5:0.0 | ||
julia> collect(y)' | ||
1×17 LinearAlgebra.Adjoint{Float64,Array{Float64,1}}: | ||
8.0 7.5 7.0 6.5 6.0 5.5 5.0 4.5 4.0 3.5 3.0 2.5 2.0 1.5 1.0 0.5 0.0 | ||
julia> findall(in(1..6), y) | ||
5:15 | ||
julia> findall(in(Interval{:open,:closed}(1,6)), y) # (1,6], does not include 1 | ||
5:14 | ||
``` | ||
""" | ||
function Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractRange) where {L,R,T} | ||
isempty(x) && return 1:0 | ||
|
||
interval = interval_d.x | ||
il, ir = firstindex(x), lastindex(x) | ||
δx = step(x) | ||
a,b = if δx < 0 | ||
rev = findall(in(interval), reverse(x)) | ||
isempty(rev) && return rev | ||
|
||
a = (il+ir)-last(rev) | ||
b = (il+ir)-first(rev) | ||
|
||
a,b | ||
else | ||
lx, rx = first(x), last(x) | ||
l = max(leftendpoint(interval), lx-1) | ||
r = min(rightendpoint(interval), rx+1) | ||
|
||
(l > rx || r < lx) && return 1:0 | ||
|
||
a = il + max(0, round(Int, cld(l-lx, δx))) | ||
a += (a ≤ ir && (x[a] == l && L == :open || x[a] < l)) | ||
|
||
b = min(ir, round(Int, cld(r-lx, δx)) + il) | ||
b -= (b ≥ il && (x[b] == r && R == :open || x[b] > r)) | ||
|
||
a,b | ||
end | ||
# Reversing a range could change sign of values close to zero (cf | ||
# sign of the smallest element in x and reverse(x), where x = | ||
# range(BigFloat(-0.5),stop=BigFloat(1.0),length=10)), or more | ||
# generally push elements in or out of the interval (as can cld), | ||
# so we need to check once again. | ||
a += +(a < ir && x[a] ∉ interval) - (il < a && x[a-1] ∈ interval) | ||
b += -(il < b && x[b] ∉ interval) + (b < ir && x[b+1] ∈ interval) | ||
|
||
a:b | ||
end | ||
|
||
# We overload Base._findin to avoid an ambiguity that arises with | ||
# Base.findall(interval_d::Base.Fix2{typeof(in),Interval{L,R,T}}, x::AbstractArray) | ||
function Base._findin(a::Union{AbstractArray, Tuple}, b::Interval) | ||
ind = Vector{eltype(keys(a))}() | ||
@inbounds for (i,ai) in pairs(a) | ||
ai in b && push!(ind, i) | ||
end | ||
ind | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
using OffsetArrays | ||
|
||
# Helper function to test that findall(in(interval), x) works. By | ||
# default, a reference is generated using the general algorithm, | ||
# linear in complexity, by generating a vector with the same contents | ||
# as x. | ||
function assert_in_interval(x, interval, | ||
expected=findall(v -> v ∈ interval, x)) | ||
|
||
result = :(findall(in($interval), $x)) | ||
expr = :($result == $expected || isempty($result) && isempty($expected)) | ||
if !(@eval $expr) | ||
println("Looking for elements of $x ∈ $interval, got $(@eval $result), expected $expected") | ||
length(x) < 30 && println(" x = ", collect(pairs(x)), "\n") | ||
end | ||
@eval @test $expr | ||
end | ||
|
||
@testset "Interval coverage" begin | ||
@testset "Basic tests" begin | ||
let x = range(0, stop=1, length=21) | ||
Random.seed!(321) | ||
@testset "$kind" for (kind,end_points) in [ | ||
("Two intervals", [(0.0, 0.5), (0.25,0.5)]), | ||
("Three intervals", [(0, 1/3), (1/3, 2/3), (2/3, 1)]), | ||
("Random intervals", [minmax(rand(),rand()) for i = 1:2]), | ||
("Interval containing one point", [(0.4619303378979984,0.5450937144417902)]), | ||
("Interval containing no points", [(0.9072957410215778,0.9082803807133988)]) | ||
] | ||
@testset "L=$L" for L=[:closed,:open] | ||
@testset "R=$R" for R=[:closed,:open] | ||
for (a,b) in end_points | ||
interval = Interval{L,R}(a, b) | ||
@testset "Reversed: $reversed" for reversed in [false, true] | ||
assert_in_interval(reversed ? reverse(x) : x, interval) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
@testset "Open interval" begin | ||
assert_in_interval(x, OpenInterval(0.2,0.4), 6:8) | ||
end | ||
end | ||
end | ||
|
||
@testset "Partially covered intervals" begin | ||
@testset "$T" for T in (Float32,Float64,BigFloat) | ||
@testset "$name, x = $x" for (name,x) in [ | ||
("Outside left",range(T(-1),stop=T(-0.5),length=10)), | ||
("Touching left",range(T(-1),stop=T(0),length=10)), | ||
("Touching left-ϵ",range(T(-1),stop=T(0)-eps(T),length=10)), | ||
("Touching left+ϵ",range(T(-1),stop=T(0)+eps(T),length=10)), | ||
|
||
("Outside right",range(T(1.5),stop=T(2),length=10)), | ||
("Touching right",range(T(1),stop=T(2),length=10)), | ||
("Touching right-ϵ",range(T(1)-eps(T),stop=T(2),length=10)), | ||
("Touching right+ϵ",range(T(1)+eps(T),stop=T(2),length=10)), | ||
|
||
("Other right",range(T(0.5),stop=T(1),length=10)), | ||
("Other right-ϵ",range(T(0.5)-eps(T(0.5)),stop=T(1),length=10)), | ||
("Other right+ϵ",range(T(0.5)+eps(T(0.5)),stop=T(1),length=10)), | ||
|
||
("Complete", range(T(0),stop=T(1),length=10)), | ||
("Complete-ϵ", range(eps(T),stop=T(1)-eps(T),length=10)), | ||
("Complete+ϵ", range(-eps(T),stop=T(1)+eps(T),length=10)), | ||
|
||
("Left partial", range(T(-0.5),stop=T(0.6),length=10)), | ||
("Left", range(T(-0.5),stop=T(1.0),length=10)), | ||
("Right partial", range(T(0.5),stop=T(1.6),length=10)), | ||
("Right", range(T(0),stop=T(1.6),length=10))] | ||
@testset "L=$L" for L=[:closed,:open] | ||
@testset "R=$R" for R=[:closed,:open] | ||
@testset "Reversed: $reversed" for reversed in [false, true] | ||
for (a,b) in [(T(0.0),T(0.5)),(T(0.5),T(1.0))] | ||
interval = Interval{L,R}(a, b) | ||
assert_in_interval(reversed ? reverse(x) : x, interval) | ||
end | ||
end | ||
end | ||
end | ||
end | ||
end | ||
end | ||
|
||
@testset "Large intervals" begin | ||
@test findall(in(4..Inf), 2:2:10) == 2:5 | ||
@test findall(in(4..1e20), 2:2:10) == 2:5 | ||
@test isempty(findall(in(-Inf..(-1e20)), 2:2:10)) | ||
end | ||
|
||
@testset "Reverse intervals" begin | ||
for x in [1:10, 1:3:10, 2:3:11, -1:9, -2:0.5:5] | ||
for lo in -3:4, hi in 5:13 | ||
for L in [:closed, :open], R in [:closed, :open] | ||
interval = Interval{L,R}(lo,hi) | ||
assert_in_interval(x, interval) | ||
assert_in_interval(reverse(x), interval) | ||
end | ||
end | ||
end | ||
end | ||
|
||
@testset "Arrays" begin | ||
@test findall(in(1..6), collect(0:7)) == 2:7 | ||
@test findall(in(1..6), reshape(1:16, 4, 4)) == | ||
vcat([CartesianIndex(i,1) for i = 1:4], CartesianIndex(1,2), CartesianIndex(2,2)) | ||
end | ||
|
||
@testset "Empty ranges and intervals" begin | ||
# Range empty | ||
@test isempty(findall(in(1..6), 1:0)) | ||
# Interval empty | ||
@test isempty(findall(in(Interval{:closed,:open}(1.0..1.0)), | ||
0.0:0.02040816326530612:1.0)) | ||
end | ||
|
||
@testset "Offset arrays" begin | ||
for (x,interval) in [(OffsetArray(ones(10), -5), -1..1), | ||
(OffsetArray(1:5, -3), 2..4), | ||
(OffsetArray(5:-1:1, -5), 2..4)] | ||
assert_in_interval(x, interval) | ||
assert_in_interval(reverse(x), interval) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters