diff --git a/Project.toml b/Project.toml index 4473467d6..d70b9896f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Requires = "0.5, 1.0" julia = "1.2" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" @@ -21,4 +22,4 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random"] +test = ["Test", "LabelledArrays", "StaticArrays", "BandedMatrices", "BlockBandedMatrices", "SuiteSparse", "Random", "Aqua"] diff --git a/README.md b/README.md index 5fcc56f64..ca1e862f9 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,14 @@ Otherwise, returns `nothing`. For example, `known_step(UnitRange{Int})` returns If `length` of an instance of type `T` is known at compile time, return it. Otherwise, return `nothing`. +## Static(N::Int) + +Creates a static integer with value known at compile time. It is a number, +supporting basic arithmetic. Many operations with two `Static` integers +will produce another `Static` integer. If one of the arguments to a +function call isn't static (e.g., `Static(4) + 3`) then the `Static` +number will promote to a dynamic value. + # List of things to add - https://github.com/JuliaLang/julia/issues/22216 diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index d666ac594..cb6aab1fa 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -699,6 +699,7 @@ function __init__() end end +include("static.jl") include("ranges.jl") end diff --git a/src/ranges.jl b/src/ranges.jl index e47be321f..601c8455e 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -43,7 +43,8 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T) # add methods to support ArrayInterface _get(x) = x -_get(::Val{V}) where {V} = V +_get(::Static{V}) where {V} = V +_get(::Type{Static{V}}) where {V} = V _convert(::Type{T}, x) where {T} = convert(T, x) _convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V)) @@ -56,7 +57,7 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in from other valid indices. Therefore, users should not expect the same checks are used to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`. """ -struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T} +struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T} start::F stop::L @@ -79,10 +80,8 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T} function OptionallyStaticUnitRange(x::AbstractRange) if step(x) == 1 - fst = known_first(x) - fst = fst === nothing ? first(x) : Val(fst) - lst = known_last(x) - lst = lst === nothing ? last(x) : Val(lst) + fst = static_first(x) + lst = static_last(x) return OptionallyStaticUnitRange(fst, lst) else throw(ArgumentError("step must be 1, got $(step(r))")) @@ -90,17 +89,17 @@ struct OptionallyStaticUnitRange{T,F,L} <: AbstractUnitRange{T} end end -Base.first(r::OptionallyStaticUnitRange{<:Any,Val{F}}) where {F} = F -Base.first(r::OptionallyStaticUnitRange{<:Any,<:Any}) = r.start +Base.:(:)(L::Integer, ::Static{U}) where {U} = OptionallyStaticUnitRange(L, Static(U)) +Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(L), U) +Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U)) +Base.first(r::OptionallyStaticUnitRange) = r.start Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T) +Base.last(r::OptionallyStaticUnitRange) = r.stop -Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}) where {L} = L -Base.last(r::OptionallyStaticUnitRange{<:Any,<:Any,<:Any}) = r.stop - -known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Val{F}}}) where {F} = F +known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T) -known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Val{L}}}) where {L} = L +known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L function Base.isempty(r::OptionallyStaticUnitRange) if known_first(r) === oneunit(eltype(r)) @@ -141,10 +140,20 @@ end return convert(eltype(r), val) end -_try_static(x, y) = Val(x) -_try_static(::Nothing, y) = Val(y) -_try_static(x, ::Nothing) = Val(x) -_try_static(::Nothing, ::Nothing) = nothing +@inline _try_static(::Static{N}, ::Static{N}) where {N} = Static{N}() +@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()" +function _try_static(::Static{N}, x) where {N} + @assert N == x "Unequal Indices: Static{$N}() != x == $x" + Static{N}() +end +function _try_static(x, ::Static{N}) where {N} + @assert N == x "Unequal Indices: x == $x != Static{$N}()" + Static{N}() +end +function _try_static(x, y) + @assert x == y "Unequal Indicess: x == $x != $y == y" + x +end ### ### length @@ -193,7 +202,7 @@ specified then indices for visiting each index of `x` is returned. """ @inline function indices(x) inds = eachindex(x) - if inds isa AbstractUnitRange{<:Integer} + if inds isa AbstractUnitRange && eltype(inds) <: Integer return Base.Slice(OptionallyStaticUnitRange(inds)) else return inds @@ -202,30 +211,24 @@ end function indices(x::Tuple) inds = map(eachindex, x) - @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" return reduce(_pick_range, inds) end -indices(x, d) = indices(axes(x, d)) +@inline indices(x, d) = indices(axes(x, d)) -@inline function indices(x::NTuple{N,<:Any}, dim) where {N} +@inline function indices(x::Tuple{Vararg{Any,N}}, dim) where {N} inds = map(x_i -> indices(x_i, dim), x) - @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" return reduce(_pick_range, inds) end -@inline function indices(x::NTuple{N,<:Any}, dim::NTuple{N,<:Any}) where {N} +@inline function indices(x::Tuple{Vararg{Any,N}}, dim::Tuple{Vararg{Any,N}}) where {N} inds = map(indices, x, dim) - @assert all(isequal(first(inds)), Base.tail(inds)) "Not all specified axes are equal: $inds" return reduce(_pick_range, inds) end @inline function _pick_range(x, y) - fst = _try_static(known_first(x), known_first(y)) - fst = fst === nothing ? first(x) : fst - - lst = _try_static(known_last(x), known_last(y)) - lst = lst === nothing ? last(x) : lst + fst = _try_static(static_first(x), static_first(y)) + lst = _try_static(static_last(x), static_last(y)) return Base.Slice(OptionallyStaticUnitRange(fst, lst)) end diff --git a/src/static.jl b/src/static.jl new file mode 100644 index 000000000..6a90cecd5 --- /dev/null +++ b/src/static.jl @@ -0,0 +1,90 @@ + +""" +A statically sized `Int`. +Use `Static(N)` instead of `Val(N)` when you want it to behave like a number. +""" +struct Static{N} <: Integer + Static{N}() where {N} = new{N::Int}() +end +Base.@pure Static(N::Int) = Static{N}() +Static(N::Integer) = Static(convert(Int, N)) +Static(::Static{N}) where {N} = Static{N}() +Static(::Val{N}) where {N} = Static{N}() +Base.Val(::Static{N}) where {N} = Val{N}() +Base.convert(::Type{T}, ::Static{N}) where {T<:Number,N} = convert(T, N) +Base.convert(::Type{Static{N}}, ::Static{N}) where {N} = Static{N}() + +Base.promote_rule(::Type{<:Static}, ::Type{T}) where {T <: AbstractIrrational} = promote_rule(Int, T) +Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T <: AbstractIrrational} = promote_rule(T, Int) +for (S,T) ∈ [(:Complex,:Real), (:Rational, :Integer), (:(Base.TwicePrecision),:Any)] + @eval Base.promote_rule(::Type{$S{T}}, ::Type{<:Static}) where {T <: $T} = promote_rule($S{T}, Int) +end +Base.promote_rule(::Type{Union{Nothing,Missing}}, ::Type{<:Static}) = Union{Nothing, Missing, Int} +Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Union{Missing,Nothing}} = promote_rule(T, Int) +Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Nothing} = promote_rule(T, Int) +Base.promote_rule(::Type{T}, ::Type{<:Static}) where {T >: Missing} = promote_rule(T, Int) +for T ∈ [:Bool, :Missing, :BigFloat, :BigInt, :Nothing, :Any] +# let S = :Any + @eval begin + Base.promote_rule(::Type{S}, ::Type{$T}) where {S <: Static} = promote_rule(Int, $T) + Base.promote_rule(::Type{$T}, ::Type{S}) where {S <: Static} = promote_rule($T, Int) + end +end +Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int +Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N + +Base.iszero(::Static{0}) = true +Base.iszero(::Static) = false +Base.isone(::Static{1}) = true +Base.isone(::Static) = false + +for T = [:Real, :Rational, :Integer] + @eval begin + @inline Base.:(+)(i::$T, ::Static{0}) = i + @inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M + @inline Base.:(+)(::Static{0}, i::$T) = i + @inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i + @inline Base.:(-)(i::$T, ::Static{0}) = i + @inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M + @inline Base.:(*)(i::$T, ::Static{0}) = Static{0}() + @inline Base.:(*)(i::$T, ::Static{1}) = i + @inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M + @inline Base.:(*)(::Static{0}, i::$T) = Static{0}() + @inline Base.:(*)(::Static{1}, i::$T) = i + @inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i + end +end +@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}() +@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}() +@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}() + +@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}() + +@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}() +@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}() +@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}() +@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}() +@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}() +@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}() +@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}() +@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}() +for f ∈ [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :(⊻)] + @eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N))) +end +for f ∈ [:(==), :(!=), :(<), :(≤), :(>), :(≥)] + @eval begin + @inline Base.$f(::Static{M}, ::Static{N}) where {M,N} = $f(M, N) + @inline Base.$f(::Static{M}, x::Int) where {M} = $f(M, x) + @inline Base.$f(x::Int, ::Static{M}) where {M} = $f(x, M) + end +end + +@inline function maybe_static(f::F, g::G, x) where {F, G} + L = f(x) + isnothing(L) ? g(x) : Static(L) +end +@inline static_length(x) = maybe_static(known_length, length, x) +@inline static_first(x) = maybe_static(known_first, first, x) +@inline static_last(x) = maybe_static(known_last, last, x) +@inline static_step(x) = maybe_static(known_step, step, x) + diff --git a/test/runtests.jl b/test/runtests.jl index 4a6902341..7e3e839f0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,11 @@ using ArrayInterface, Test using Base: setindex -import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance +import ArrayInterface: has_sparsestruct, findstructralnz, fast_scalar_indexing, lu_instance, Static @test ArrayInterface.ismutable(rand(3)) +using Aqua +Aqua.test_all(ArrayInterface) + using StaticArrays x = @SVector [1,2,3] @test ArrayInterface.ismutable(x) == false @@ -220,12 +223,51 @@ end end @testset "indices" begin - @test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)))) == 1:6 - @test @inferred(ArrayInterface.indices(ones(2, 3))) == 1:6 - @test @inferred(ArrayInterface.indices(ones(2, 3), 1)) == 1:2 - @test @inferred(ArrayInterface.indices((ones(2, 3), ones(3, 2)), (1, 2))) == 1:2 - @test @inferred(ArrayInterface.indices((ones(2, 3), ones(2, 3)), 1)) == 1:2 - @test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), 1) - @test_throws AssertionError ArrayInterface.indices((ones(2, 3), ones(3, 3)), (1, 2)) + A23 = ones(2,3); SA23 = @SMatrix ones(2,3); + A32 = ones(3,2); SA32 = @SMatrix ones(3,2); + @test @inferred(ArrayInterface.indices((A23, A32))) == 1:6 + @test @inferred(ArrayInterface.indices((SA23, A32))) == 1:6 + @test @inferred(ArrayInterface.indices((A23, SA32))) == 1:6 + @test @inferred(ArrayInterface.indices((SA23, SA32))) == 1:6 + @test @inferred(ArrayInterface.indices(A23)) == 1:6 + @test @inferred(ArrayInterface.indices(SA23)) == 1:6 + @test @inferred(ArrayInterface.indices(A23, 1)) == 1:2 + @test @inferred(ArrayInterface.indices(SA23, Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((A23, A32), (1, 2))) == 1:2 + @test @inferred(ArrayInterface.indices((SA23, A32), (Static(1), 2))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((A23, SA32), (1, Static(2)))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((SA23, SA32), (Static(1), Static(2)))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((A23, A23), 1)) == 1:2 + @test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((SA23, A23), Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((A23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) + @test @inferred(ArrayInterface.indices((SA23, SA23), Static(1))) === Base.Slice(Static(1):Static(2)) + @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), 1) + @test_throws AssertionError ArrayInterface.indices((A23, ones(3, 3)), (1, 2)) + @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), Static(1)) + @test_throws AssertionError ArrayInterface.indices((SA23, ones(3, 3)), (Static(1), 2)) + @test_throws AssertionError ArrayInterface.indices((SA23, SA23), (Static(1), Static(2))) +end + +@testset "Static" begin + @test iszero(Static(0)) + @test !iszero(Static(1)) + # test for ambiguities and correctness + for i ∈ [Static(0), Static(1), Static(2), 3] + for j ∈ [Static(0), Static(1), Static(2), 3] + i === j === 3 && continue + for f ∈ [+, -, *, ÷, %, <<, >>, >>>, &, |, ⊻, ==, ≤, ≥] + (iszero(j) && ((f === ÷) || (f === %))) && continue # integer division error + @test convert(Int, @inferred(f(i,j))) == f(convert(Int, i), convert(Int, j)) + end + end + i == 3 && break + for f ∈ [+, -, *, /, ÷, %, ==, ≤, ≥] + x = f(convert(Int, i), 1.4) + y = f(1.4, convert(Int, i)) + @test convert(typeof(x), @inferred(f(i, 1.4))) === x + @test convert(typeof(y), @inferred(f(1.4, i))) === y # if f is division and i === Static(0), returns `NaN`; hence use of ==== in check. + end + end end