Skip to content

Commit

Permalink
Merge pull request #68 from Tokazama/master
Browse files Browse the repository at this point in the history
Non-mutating versions of pop, popfirst, etc. (#66)
  • Loading branch information
chriselrod authored Sep 14, 2020
2 parents 85e93de + bc3447d commit e49ec67
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 62 deletions.
99 changes: 98 additions & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Requires
using LinearAlgebra
using SparseArrays

using Base: OneTo
using Base: OneTo, @propagate_inbounds

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
Expand Down Expand Up @@ -543,6 +543,103 @@ function restructure(x::Array,y)
reshape(convert(Array,y),size(x)...)
end

"""
insert(collection, index, item)
Return a new instance of `collection` with `item` inserted into at the given `index`.
"""
Base.@propagate_inbounds function insert(collection, index, item)
@boundscheck checkbounds(collection, index)
ret = similar(collection, length(collection) + 1)
@inbounds for i in firstindex(ret):(index - 1)
ret[i] = collection[i]
end
@inbounds ret[index] = item
@inbounds for i in (index + 1):lastindex(ret)
ret[i] = collection[i - 1]
end
return ret
end

function insert(x::Tuple, index::Integer, item)
@boundscheck if !checkindex(Bool, static_first(x):static_last(x), index)
throw(BoundsError(x, index))
end
return unsafe_insert(x, Int(index), item)
end

@inline function unsafe_insert(x::Tuple, i::Int, item)
if i === 1
return (item, x...)
else
return (first(x), unsafe_insert(Base.tail(x), i - 1, item)...)
end
end

"""
deleteat(collection, index)
Return a new instance of `collection` with the item at the given `index` removed.
"""
@propagate_inbounds function deleteat(collection::AbstractVector, index)
@boundscheck if !checkindex(Bool, eachindex(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end
@propagate_inbounds function deleteat(collection::Tuple, index)
@boundscheck if !checkindex(Bool, static_first(collection):static_last(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end

function unsafe_deleteat(src::AbstractVector, index::Integer)
dst = similar(src, length(src) - 1)
@inbounds for i in indices(dst)
if i < index
dst[i] = src[i]
else
dst[i] = src[i + 1]
end
end
return dst
end

@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector)
dst = similar(src, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in indices(src)
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return dst
end

@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in OneTo(length(src))
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return Tuple(dst)
end

@inline function unsafe_deleteat(x::Tuple, i::Integer)
if i === one(i)
return Base.tail(x)
elseif i == length(x)
return Base.front(x)
else
return (first(x), unsafe_deleteat(Base.tail(x), i - one(i))...)
end
end

function __init__()

@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down
61 changes: 21 additions & 40 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

# add methods to support ArrayInterface

_get(x) = x
_get(::Static{V}) where {V} = V
_get(::Type{Static{V}}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

"""
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}
Expand All @@ -57,28 +51,23 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRange{Int}
start::F
stop::L

function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
if _get(start) isa T
if _get(stop) isa T
return new{T,typeof(start),typeof(stop)}(start, stop)
function OptionallyStaticUnitRange(start, stop)
if eltype(start) <: Int
if eltype(stop) <: Int
return new{typeof(start),typeof(stop)}(start, stop)
else
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
return OptionallyStaticUnitRange(start, Int(stop))
end
else
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
return OptionallyStaticUnitRange(Int(start), stop)
end
end

function OptionallyStaticUnitRange(start, stop)
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
return OptionallyStaticUnitRange{T}(start, stop)
end

function OptionallyStaticUnitRange(x::AbstractRange)
function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = static_first(x)
lst = static_last(x)
Expand All @@ -94,12 +83,12 @@ Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))

Base.first(r::OptionallyStaticUnitRange) = r.start
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
Base.step(::OptionallyStaticUnitRange) = Static(1)
Base.last(r::OptionallyStaticUnitRange) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L
known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange}) = 1
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
Expand All @@ -112,10 +101,8 @@ end
unsafe_isempty_one_to(lst) = lst <= zero(lst)
unsafe_isempty_unit_range(fst, lst) = fst > lst

unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))

unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
unsafe_length_one_to(lst::Int) = lst
unsafe_length_one_to(::Static{L}) where {L} = lst

Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
if known_first(r) === oneunit(r)
Expand Down Expand Up @@ -144,15 +131,15 @@ end
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
function _try_static(::Static{N}, x) where {N}
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
Static{N}()
return Static{N}()
end
function _try_static(x, ::Static{N}) where {N}
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
Static{N}()
return Static{N}()
end
function _try_static(x, y)
@assert x == y "Unequal Indicess: x == $x != $y == y"
x
return x
end

###
Expand All @@ -172,24 +159,19 @@ end
end
end

function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
function Base.length(r::OptionallyStaticUnitRange)
if isempty(r)
return zero(T)
return 0
else
if known_one(r) === one(T)
if known_first(r) === 0
return unsafe_length_one_to(last(r))
else
return unsafe_length_unit_range(first(r), last(r))
end
end
end

function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
end
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
return Base.checked_add(lst - fst, one(T))
end
unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1)

"""
indices(x[, d])
Expand Down Expand Up @@ -231,4 +213,3 @@ end
lst = _try_static(static_last(x), static_last(y))
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

49 changes: 28 additions & 21 deletions src/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ Use `Static(N)` instead of `Val(N)` when you want it to behave like a number.
struct Static{N} <: Integer
Static{N}() where {N} = new{N::Int}()
end

const Zero = Static{0}
const One = Static{1}

Base.@pure Static(N::Int) = Static{N}()
Static(N::Integer) = Static(convert(Int, N))
Static(::Static{N}) where {N} = Static{N}()
Expand Down Expand Up @@ -33,41 +37,44 @@ end
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N

Base.iszero(::Static{0}) = true
Base.eltype(::Type{T}) where {T<:Static} = Int
Base.iszero(::Zero) = true
Base.iszero(::Static) = false
Base.isone(::Static{1}) = true
Base.isone(::One) = true
Base.isone(::Static) = false
Base.zero(::Type{T}) where {T<:Static} = Zero()
Base.one(::Type{T}) where {T<:Static} = One()

for T = [:Real, :Rational, :Integer]
@eval begin
@inline Base.:(+)(i::$T, ::Static{0}) = i
@inline Base.:(+)(i::$T, ::Zero) = i
@inline Base.:(+)(i::$T, ::Static{M}) where {M} = i + M
@inline Base.:(+)(::Static{0}, i::$T) = i
@inline Base.:(+)(::Zero, i::$T) = i
@inline Base.:(+)(::Static{M}, i::$T) where {M} = M + i
@inline Base.:(-)(i::$T, ::Static{0}) = i
@inline Base.:(-)(i::$T, ::Zero) = i
@inline Base.:(-)(i::$T, ::Static{M}) where {M} = i - M
@inline Base.:(*)(i::$T, ::Static{0}) = Static{0}()
@inline Base.:(*)(i::$T, ::Static{1}) = i
@inline Base.:(*)(i::$T, ::Zero) = Zero()
@inline Base.:(*)(i::$T, ::One) = i
@inline Base.:(*)(i::$T, ::Static{M}) where {M} = i * M
@inline Base.:(*)(::Static{0}, i::$T) = Static{0}()
@inline Base.:(*)(::Static{1}, i::$T) = i
@inline Base.:(*)(::Zero, i::$T) = Zero()
@inline Base.:(*)(::One, i::$T) = i
@inline Base.:(*)(::Static{M}, i::$T) where {M} = M * i
end
end
@inline Base.:(+)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(+)(::Static{0}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Static{0}) where {M} = Static{M}()
@inline Base.:(+)(::Zero, ::Zero) = Zero()
@inline Base.:(+)(::Zero, ::Static{M}) where {M} = Static{M}()
@inline Base.:(+)(::Static{M}, ::Zero) where {M} = Static{M}()

@inline Base.:(-)(::Static{M}, ::Static{0}) where {M} = Static{M}()
@inline Base.:(-)(::Static{M}, ::Zero) where {M} = Static{M}()

@inline Base.:(*)(::Static{0}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{0}) = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{1}) = Static{0}()
@inline Base.:(*)(::Static{1}, ::Static{1}) = Static{1}()
@inline Base.:(*)(::Static{M}, ::Static{0}) where {M} = Static{0}()
@inline Base.:(*)(::Static{0}, ::Static{M}) where {M} = Static{0}()
@inline Base.:(*)(::Static{M}, ::Static{1}) where {M} = Static{M}()
@inline Base.:(*)(::Static{1}, ::Static{M}) where {M} = Static{M}()
@inline Base.:(*)(::Zero, ::Zero) = Zero()
@inline Base.:(*)(::One, ::Zero) = Zero()
@inline Base.:(*)(::Zero, ::One) = Zero()
@inline Base.:(*)(::One, ::One) = One()
@inline Base.:(*)(::Static{M}, ::Zero) where {M} = Zero()
@inline Base.:(*)(::Zero, ::Static{M}) where {M} = Zero()
@inline Base.:(*)(::Static{M}, ::One) where {M} = Static{M}()
@inline Base.:(*)(::One, ::Static{M}) where {M} = Static{M}()
for f [:(+), :(-), :(*), :(/), :(÷), :(%), :(<<), :(>>), :(>>>), :(&), :(|), :()]
@eval @generated Base.$f(::Static{M}, ::Static{N}) where {M,N} = Expr(:call, Expr(:curly, :Static, $f(M, N)))
end
Expand Down
22 changes: 22 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ end
@testset "Static" begin
@test iszero(Static(0))
@test !iszero(Static(1))
@test @inferred(one(Static)) === Static(1)
@test @inferred(zero(Static)) === Static(0)
@test eltype(one(Static)) <: Int
# test for ambiguities and correctness
for i [Static(0), Static(1), Static(2), 3]
for j [Static(0), Static(1), Static(2), 3]
Expand All @@ -271,3 +274,22 @@ end
end
end

@testset "insert/deleteat" begin
@test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3]

@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 2])) == [3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1]


@test @inferred(ArrayInterface.insert((1,2,3), 1, -2)) == (-2, 1, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 2, -2)) == (1, -2, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 3, -2)) == (1, 2, -2, 3)

@test @inferred(ArrayInterface.deleteat((1, 2, 3), 1)) == (2, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 2)) == (1, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 3)) == (1, 2)
@test ArrayInterface.deleteat((1, 2, 3), [1, 2]) == (3,)
end

0 comments on commit e49ec67

Please sign in to comment.