Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added tests for Static. #69

Merged
merged 12 commits into from
Sep 9, 2020
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Requires = "0.5, 1.0"
julia = "1.2"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand All @@ -21,4 +22,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random"]
test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"]
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns
If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.

## Static(N::Int)

Creates a static integer with value known at compile time. It is a number,
supporting basic arithmetic. Many operations with two `Static` integers
will produce another `Static` integer. If one of the arguments to a
function call isn't static (e.g., `Static(4) + 3`) then the `Static`
number will promote to a dynamic value.

# List of things to add

- https://github.com/JuliaLang/julia/issues/22216
Expand Down
1 change: 1 addition & 0 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ function __init__()
end
end

include("static.jl")
include("ranges.jl")

end
60 changes: 32 additions & 28 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

_get(x) = x
_get(::Val{V}) where {V} = V
chriselrod marked this conversation as resolved.
Show resolved Hide resolved
_get(::Static{V}) where {V} = V
_get(::Type{Static{V}}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

Expand All @@ -56,7 +58,7 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
start::F
stop::L

Expand All @@ -79,28 +81,26 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T}

function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = known_first(x)
fst = fst === nothing ? first(x) : Val(fst)
lst = known_last(x)
lst = lst === nothing ? last(x) : Val(lst)
fst = static_first(x)
lst = static_last(x)
return OptionallyStaticUnitRange(fst, lst)
else
throw(ArgumentError("step must be 1, got $(step(r))"))
end
end
end

Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F
Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start
Base.:(:)(L::Integer, ::Static{U}) where {U} = OptionallyStaticUnitRange(L, Static(U))
Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(L), U)
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))

Base.first(r::OptionallyStaticUnitRange) = r.start
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
Base.last(r::OptionallyStaticUnitRange) = r.stop

Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L
Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F
known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
Expand Down Expand Up @@ -141,10 +141,20 @@ end
return convert(eltype(r), val)
end

_try_static(x, y) = Val(x)
_try_static(::Nothing, y) = Val(y)
_try_static(x, ::Nothing) = Val(x)
_try_static(::Nothing, ::Nothing) = nothing
@inline _try_static(::Static{N}, ::Static{N}) where {N} = Static{N}()
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
function _try_static(::Static{N}, x) where {N}
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
Static{N}()
end
function _try_static(x, ::Static{N}) where {N}
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
Static{N}()
end
function _try_static(x, y)
@assert x == y "Unequal Indicess: x == $x != $y == y"
x
end

###
### length
Expand Down Expand Up @@ -193,7 +203,7 @@ specified then indices for visiting each index of `x` is returned.
"""
@inline function indices(x)
inds = eachindex(x)
if inds isa AbstractUnitRange{<:Integer}
if inds isa AbstractUnitRange#{<:Integer} # prevents inference
chriselrod marked this conversation as resolved.
Show resolved Hide resolved
return Base.Slice(OptionallyStaticUnitRange(inds))
else
return inds
Expand All @@ -202,30 +212,24 @@ end

function indices(x::Tuple)
inds = map(eachindex, x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

indices(x, d) = indices(axes(x, d))
@inline indices(x, d) = indices(axes(x, d))

@inline function indices(x::NTuple{N,<:Any}, dim) where {N}
@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N}
inds = map(x_i -> indices(x_i, dim), x)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N}
@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N}
inds = map(indices, x, dim)
@assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds"
return reduce(_pick_range, inds)
end

@inline function _pick_range(x, y)
fst = _try_static(known_first(x), known_first(y))
fst = fst === nothing ? first(x) : fst

lst = _try_static(known_last(x), known_last(y))
lst = lst === nothing ? last(x) : lst
fst = _try_static(static_first(x), static_first(y))
lst = _try_static(static_last(x), static_last(y))
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

95 changes: 95 additions & 0 deletions src/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@

"""
A statically sized `Int`.
Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
"""
struct Static{N} <: Integer
Static{N}() where {N} = new{N::Int}()
end
Base.@pure Static(N::Int) = Static{N}()
Static(N::Integer) = Static(convert(Int, N))
Static(::Static{N}) where {N} = Static{N}()
Static(::Val{N}) where {N} = Static{N}()
Base.Val(::Static{N}) where {N} = Val{N}()
Base.convert(::Type{T}, ::Static{N}) where {T<:Number,N} = convert(T, N)
Base.convert(::Type{Static{N}}, ::Static{N}) where {N} = Static{N}()
# for S ∈ [:Any, :AbstractIrrational]#, :(Complex{<:Real})]
# let S = :Any
let S = :AbstractIrrational
@eval begin
Base.promote_rule(::Type{<:Static}, ::Type{T}) where {T <: $S} = promote_rule(Int, T)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T <: $S} = promote_rule(T, Int)
end
end
for (S,T) ∈ [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)]
@eval Base.promote_rule(::Type{$S{T}}, ::Type{<:Static}) where {T <: $T} = promote_rule($S{T}, Int)
end
Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:Static}) = Union{Nothing, Missing, Int}
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Nothing} = promote_rule(T, Int)
Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Missing} = promote_rule(T, Int)
for T ∈ [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any]
# let S = :Any
@eval begin
Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: Static} = promote_rule(Int, $T)
Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: Static} = promote_rule($T, Int)
end
end
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N

Base.iszero(::Static{0}) = true
Base.iszero(::Static) = false
Base.isone(::Static{1}) = true
Base.isone(::Static) = false

for T = [:Real, :Rational, :Integer]
@eval begin
@inline Base.:(+)(i::$T, ::Static{0}) = i
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
@inline Base.:(+)(::Static{0}, i::$T) = i
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
@inline Base.:(-)(i::$T, ::Static{0}) = i
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
@inline Base.:(*)(i::$T, ::Static{1}) = i
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
@inline Base.:(*)(::Static{1}, i::$T) = i
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
end
end
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()

@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()

@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
for f ∈ [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)]
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
end
for f ∈ [:(==), :(!=), :(<), :(≤), :(>), :(≥)]
@eval begin
@inline Base.$f(::Static{M}, ::Static{N}) where {M,N} = $f(M, N)
@inline Base.$f(::Static{M}, x::Int) where {M} = $f(M, x)
@inline Base.$f(x::Int, ::Static{M}) where {M} = $f(x, M)
end
end

@inline function maybe_static(f::F, g::G, x) where {F, G}
L = f(x)
isnothing(L) ? g(x) : Static(L)
end
@inline static_length(x) = maybe_static(known_length, length, x)
@inline static_first(x) = maybe_static(known_first, first, x)
@inline static_last(x) = maybe_static(known_last, last, x)
@inline static_step(x) = maybe_static(known_step, step, x)

58 changes: 50 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using ArrayInterface, Test
using Base: setindex
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance
import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, Static
@test ArrayInterface.ismutable(rand(3))

using Aqua
Aqua.test_all(ArrayInterface)

using StaticArrays
x = @SVector [1,2,3]
@test ArrayInterface.ismutable(x) == false
Expand Down Expand Up @@ -220,12 +223,51 @@ end
end

@testset "indices" begin
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6
@test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2
@test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1)
@test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2))
A23 = ones(2,3); SA23 = @SMatrix ones(2,3);
A32 = ones(3,2); SA32 = @SMatrix ones(3,2);
@test @inferred(ArrayInterface.indices((A23, A32))) == 1:6
@test @inferred(ArrayInterface.indices((SA23, A32))) == 1:6
@test @inferred(ArrayInterface.indices((A23, SA32))) == 1:6
@test @inferred(ArrayInterface.indices((SA23, SA32))) == 1:6
@test @inferred(ArrayInterface.indices(A23)) == 1:6
@test @inferred(ArrayInterface.indices(SA23)) == 1:6
@test @inferred(ArrayInterface.indices(A23, 1)) == 1:2
@test @inferred(ArrayInterface.indices(SA23, Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, A32), (1, 2))) == 1:2
@test @inferred(ArrayInterface.indices((SA23, A32), (Static(1), 2))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, SA32), (1, Static(2)))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, SA32), (Static(1), Static(2)))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, A23), 1)) == 1:2
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, A23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((A23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2))
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1)
@test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2))
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), Static(1))
@test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (Static(1), 2))
@test_throws AssertionError ArrayInterface.indices((SA23, SA23), (Static(1), Static(2)))
end

@testset "Static" begin
@test iszero(Static(0))
@test !iszero(Static(1))
# test for ambiguities and correctness
for i ∈ [Static(0), Static(1), Static(2), 3]
for j ∈ [Static(0), Static(1), Static(2), 3]
i === j === 3 && continue
for f ∈ [+, -, *, ÷, %, <<, >>, >>>, &, |, ⊻, ==, ≤, ≥]
(iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error
@test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j))
end
end
i == 3 && break
for f ∈ [+, -, *, /, ÷, %, ==, ≤, ≥]
x = f(convert(Int, i), 1.4)
y = f(1.4, convert(Int, i))
@test convert(typeof(x), @inferred(f(i, 1.4))) === x
@test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === Static(0), returns `NaN`; hence use of ==== in check.
end
end
end